1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api.h"
17 
18 #include <algorithm>
19 #include <limits>
20 #include <memory>
21 #include <vector>
22 
23 #include "absl/strings/match.h"
24 // Required for IS_MOBILE_PLATFORM
25 #include "tensorflow/core/platform/platform.h"  // NOLINT
26 
27 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
28 #include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
29 #include "tensorflow/cc/framework/gradients.h"
30 #include "tensorflow/cc/framework/ops.h"
31 #include "tensorflow/cc/framework/scope_internal.h"
32 #include "tensorflow/cc/ops/while_loop.h"
33 #include "tensorflow/cc/saved_model/loader.h"
34 #include "tensorflow/core/distributed_runtime/server_lib.h"
35 #include "tensorflow/core/framework/logging.h"
36 #include "tensorflow/core/framework/op_gen_lib.h"
37 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
38 #include "tensorflow/c/c_api_internal.h"
39 #include "tensorflow/c/tf_status_internal.h"
40 #include "tensorflow/c/tf_tensor.h"
41 #include "tensorflow/core/common_runtime/device_mgr.h"
42 #include "tensorflow/core/common_runtime/eval_const_tensor.h"
43 #include "tensorflow/core/common_runtime/graph_constructor.h"
44 #include "tensorflow/core/common_runtime/shape_refiner.h"
45 #include "tensorflow/core/framework/allocation_description.pb.h"
46 #include "tensorflow/core/framework/kernel_def.pb.h"
47 #include "tensorflow/core/framework/log_memory.h"
48 #include "tensorflow/core/framework/node_def_util.h"
49 #include "tensorflow/core/framework/op_kernel.h"
50 #include "tensorflow/core/framework/partial_tensor_shape.h"
51 #include "tensorflow/core/framework/tensor.h"
52 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
53 #include "tensorflow/core/framework/tensor_shape.h"
54 #include "tensorflow/core/framework/tensor_shape.pb.h"
55 #include "tensorflow/core/framework/types.h"
56 #include "tensorflow/core/framework/versions.pb.h"
57 #include "tensorflow/core/graph/graph.h"
58 #include "tensorflow/core/graph/node_builder.h"
59 #include "tensorflow/core/graph/validate.h"
60 #include "tensorflow/core/lib/gtl/array_slice.h"
61 #include "tensorflow/core/platform/coding.h"
62 #include "tensorflow/core/platform/errors.h"
63 #include "tensorflow/core/platform/mem.h"
64 #include "tensorflow/core/platform/mutex.h"
65 #include "tensorflow/core/platform/protobuf.h"
66 #include "tensorflow/core/platform/status.h"
67 #include "tensorflow/core/platform/str_util.h"
68 #include "tensorflow/core/platform/strcat.h"
69 #include "tensorflow/core/platform/stringpiece.h"
70 #include "tensorflow/core/platform/thread_annotations.h"
71 #include "tensorflow/core/platform/types.h"
72 #include "tensorflow/core/public/session.h"
73 #include "tensorflow/core/public/version.h"
74 
75 // The implementation below is at the top level instead of the
76 // brain namespace because we are defining 'extern "C"' functions.
77 using tensorflow::AllocationDescription;
78 using tensorflow::DataType;
79 using tensorflow::ExtendSessionGraphHelper;
80 using tensorflow::Graph;
81 using tensorflow::GraphDef;
82 using tensorflow::mutex_lock;
83 using tensorflow::NameRangeMap;
84 using tensorflow::NameRangesForNode;
85 using tensorflow::NewSession;
86 using tensorflow::Node;
87 using tensorflow::NodeBuilder;
88 using tensorflow::NodeDef;
89 using tensorflow::OpDef;
90 using tensorflow::OpRegistry;
91 using tensorflow::OutputTensor;
92 using tensorflow::PartialTensorShape;
93 using tensorflow::RunMetadata;
94 using tensorflow::RunOptions;
95 using tensorflow::Session;
96 using tensorflow::Status;
97 using tensorflow::string;
98 using tensorflow::Tensor;
99 using tensorflow::TensorBuffer;
100 using tensorflow::TensorId;
101 using tensorflow::TensorShape;
102 using tensorflow::TensorShapeProto;
103 using tensorflow::VersionDef;
104 using tensorflow::errors::FailedPrecondition;
105 using tensorflow::errors::InvalidArgument;
106 using tensorflow::gtl::ArraySlice;
107 using tensorflow::strings::StrCat;
108 
109 extern "C" {
110 
111 // --------------------------------------------------------------------------
112 const char* TF_Version() { return TF_VERSION_STRING; }
113 
114 // --------------------------------------------------------------------------
115 
116 // --------------------------------------------------------------------------
117 TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
118 void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
119 
120 void TF_SetTarget(TF_SessionOptions* options, const char* target) {
121   options->options.target = target;
122 }
123 
124 void TF_SetConfig(TF_SessionOptions* options, const void* proto,
125                   size_t proto_len, TF_Status* status) {
126   if (!options->options.config.ParseFromArray(proto, proto_len)) {
127     status->status = InvalidArgument("Unparseable ConfigProto");
128   }
129 }
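// Illustrative usage sketch (commentary added here, not part of the original
// source): a caller would typically build options like this, where
// `config_bytes`/`config_len` are assumed to hold a serialized ConfigProto.
//
//   TF_Status* status = TF_NewStatus();
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_SetTarget(opts, "");  // empty target selects the in-process runtime
//   TF_SetConfig(opts, config_bytes, config_len, status);
//   if (TF_GetCode(status) != TF_OK) { /* handle the parse error */ }
//   // ... create a session with `opts` ...
//   TF_DeleteSessionOptions(opts);
//   TF_DeleteStatus(status);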
130 // --------------------------------------------------------------------------
131 TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; }
132 
133 TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
134   void* copy = tensorflow::port::Malloc(proto_len);
135   memcpy(copy, proto, proto_len);
136 
137   TF_Buffer* buf = new TF_Buffer;
138   buf->data = copy;
139   buf->length = proto_len;
140   buf->data_deallocator = [](void* data, size_t length) {
141     tensorflow::port::Free(data);
142   };
143   return buf;
144 }
145 
146 void TF_DeleteBuffer(TF_Buffer* buffer) {
147   if (buffer == nullptr) return;
148   if (buffer->data_deallocator != nullptr) {
149     (*buffer->data_deallocator)(const_cast<void*>(buffer->data),
150                                 buffer->length);
151   }
152   delete buffer;
153 }
154 
155 TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }
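// Illustrative usage sketch (added commentary): TF_NewBufferFromString makes a
// heap copy of the caller's bytes, so the returned buffer must eventually be
// released with TF_DeleteBuffer, which runs data_deallocator on the copy.
// `proto_bytes`/`proto_len` below are assumed inputs.
//
//   TF_Buffer* buf = TF_NewBufferFromString(proto_bytes, proto_len);
//   // ... pass `buf` to an API call that reads it ...
//   TF_DeleteBuffer(buf);  // frees the copy via tensorflow::port::Free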
156 
157 // --------------------------------------------------------------------------
158 
159 TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
160                                               TF_Status* status) {
161   Session* session;
162   status->status = NewSession(opt->options, &session);
163   if (status->status.ok()) {
164     return new TF_DeprecatedSession({session});
165   } else {
166     DCHECK_EQ(nullptr, session);
167     return nullptr;
168   }
169 }
170 
171 void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
172   status->status = s->session->Close();
173 }
174 
175 void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
176   status->status = Status::OK();
177   if (s == nullptr) return;
178   delete s->session;
179   delete s;
180 }
181 
182 void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
183                     size_t proto_len, TF_Status* status) {
184   GraphDef g;
185   if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
186     status->status = InvalidArgument("Invalid GraphDef");
187     return;
188   }
189   status->status = s->session->Extend(g);
190 }
191 
192 }  // end extern "C"
193 
194 // Reset helper for converting character arrays to string vectors.
195 static void TF_Reset_Helper(const TF_SessionOptions* opt,
196                             const char** containers, int ncontainers,
197                             TF_Status* status) {
198   std::vector<string> container_names(ncontainers);
199   for (int i = 0; i < ncontainers; ++i) {
200     container_names[i] = containers[i];
201   }
202 
203   status->status = Reset(opt->options, container_names);
204 }
205 
206 extern "C" {
207 
208 void TF_Reset(const TF_SessionOptions* opt, const char** containers,
209               int ncontainers, TF_Status* status) {
210   TF_Reset_Helper(opt, containers, ncontainers, status);
211 }
212 
213 }  // end extern "C"
214 
215 namespace tensorflow {
216 
217 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
218                        TF_Buffer* out) {
219   if (out->data != nullptr) {
220     return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
221   }
222   const size_t proto_size = in.ByteSizeLong();
223   void* buf = port::Malloc(proto_size);
224   if (buf == nullptr) {
225     return tensorflow::errors::ResourceExhausted(
226         "Failed to allocate memory to serialize message of type '",
227         in.GetTypeName(), "' and size ", proto_size);
228   }
229   if (!in.SerializeWithCachedSizesToArray(static_cast<uint8*>(buf))) {
230     port::Free(buf);
231     return InvalidArgument("Unable to serialize ", in.GetTypeName(),
232                            " protocol buffer, perhaps the serialized size (",
233                            proto_size, " bytes) is too large?");
234   }
235   out->data = buf;
236   out->length = proto_size;
237   out->data_deallocator = [](void* data, size_t length) { port::Free(data); };
238   return Status::OK();
239 }
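// Illustrative calling pattern (added commentary): pass a freshly allocated,
// empty TF_Buffer, since a non-empty one is rejected above; `some_proto` is a
// stand-in for any protobuf MessageLite.
//
//   TF_Buffer* out = TF_NewBuffer();
//   tensorflow::Status s = tensorflow::MessageToBuffer(some_proto, out);
//   if (!s.ok()) TF_DeleteBuffer(out);  // otherwise `out` now owns the bytes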
240 
241 void RecordMutation(TF_Graph* graph, const TF_Operation& op,
242                     const char* mutation_type) {
243   // If any session has already run this node_id, mark this session as
244   // unrunnable.
245   for (auto it : graph->sessions) {
246     mutex_lock session_lock(it.first->mu);
247     if (it.first->last_num_graph_nodes > op.node.id()) {
248       it.second = strings::StrCat(
249           "Operation '", op.node.DebugString(), "' was changed by ",
250           mutation_type,
251           " after it was run by a session. This mutation will have no effect, "
252           "and will trigger an error in the future. Either don't modify "
253           "nodes after running them or create a new session.");
254     }
255   }
256 }
257 
258 namespace {
259 
260 // Helper method that creates a shape handle for a shape described by dims.
261 tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
262     tensorflow::shape_inference::InferenceContext* ic, int num_dims,
263     const int64_t* dims) {
264   if (num_dims != -1) {
265     std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
266     dim_vec.reserve(num_dims);
267     for (int i = 0; i < num_dims; ++i) {
268       dim_vec.push_back(ic->MakeDim(dims[i]));
269     }
270     return ic->MakeShape(dim_vec);
271   } else {
272     return ic->UnknownShape();
273   }
274 }
275 
276 }  // namespace
277 
278 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
279                                            int num_shapes_and_types,
280                                            const int64_t** shapes,
281                                            const int* ranks,
282                                            const TF_DataType* types,
283                                            TF_Status* status) {
284   Node* node = &output.oper->node;
285 
286   mutex_lock l(graph->mu);
287   tensorflow::shape_inference::InferenceContext* ic =
288       graph->refiner.GetContext(node);
289   if (ic == nullptr) {
290     status->status =
291         InvalidArgument("Node ", node->name(), " was not found in the graph");
292     return;
293   }
294 
295   auto shape_and_type_vec =
296       std::vector<tensorflow::shape_inference::ShapeAndType>(
297           num_shapes_and_types);
298   for (int i = 0; i < num_shapes_and_types; ++i) {
299     tensorflow::shape_inference::ShapeHandle shape_handle =
300         ShapeHandleFromDims(ic, ranks[i], shapes[i]);
301     shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
302         shape_handle, static_cast<DataType>(types[i]));
303   }
304 
305   ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
306 }
307 
308 // Helpers for loading a TensorFlow plugin (a .so file).
309 Status LoadDynamicLibrary(const char* library_filename, void** result,
310                           const void** buf, size_t* len);
311 
312 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
313 // directly, instead of requiring us to serialize to a GraphDef and
314 // call Session::Extend().
315 bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
316   if (session->graph != nullptr) {
317     // Take the graph lock before the session lock to avoid deadlock. This is
318     // safe since session->graph does not change.
319     session->graph->mu.lock();
320     mutex_lock session_lock(session->mu);
321     const Graph& graph = session->graph->graph;
322 
323     const string& mutation_warning = session->graph->sessions[session];
324     if (!mutation_warning.empty()) {
325       // TODO(b/74949947): turn this back into an error status
326       LOG(WARNING) << mutation_warning;
327       session->graph->sessions[session].clear();
328     }
329 
330     const auto num_nodes = graph.num_node_ids();
331     if (session->last_num_graph_nodes < num_nodes) {
332       // TODO(nolivia): check this on a subset of the graph instead of all of
333       // it.
334       status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
335       if (!status->status.ok()) {
336         session->graph->mu.unlock();
337         return false;
338       }
339 
340       GraphDef graph_def;
341       *graph_def.mutable_versions() = graph.versions();
342       // Fill graph_def with nodes with ids in the range
343       // [session->last_num_graph_nodes, num_nodes), that is the nodes
344       // added since the last TF_SessionRun() call.
345       for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
346         Node* const node = graph.FindNodeId(id);
347         if (node != nullptr && node->IsOp()) {
348           NodeDef* const node_def = graph_def.add_node();
349           *node_def = node->def();
350         }
351       }
352       *graph_def.mutable_library() = graph.flib_def().ToProto();
353       session->graph->mu.unlock();
354       status->status = session->session->Extend(std::move(graph_def));
355       if (!status->status.ok()) {
356         // Contract is we always delete input_values[i].
357         return false;
358       }
359       // Note: session->session is not modified if Extend() fails, so
360       // we only set last_num_graph_nodes if it succeeds.
361       session->last_num_graph_nodes = num_nodes;
362     } else {
363       session->graph->mu.unlock();
364     }
365   }
366   return true;
367 }
368 
369 }  // namespace tensorflow
370 
371 static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
372                          TF_Status* status) {
373   status->status = Status::OK();
374   for (int i = 0; i < noutputs; ++i) {
375     c_outputs[i] = nullptr;
376   }
377 }
378 
379 static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
380                           std::vector<std::pair<string, Tensor>>* input_pairs,
381                           TF_Status* status) {
382   const int ninputs = input_pairs->size();
383   for (int i = 0; i < ninputs; ++i) {
384     status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
385     if (!status->status.ok()) return false;
386   }
387   return true;
388 }
389 
390 // Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
391 // result in a zero-sized tensor.
392 static TF_Tensor* EmptyTensor(TF_DataType dtype,
393                               const tensorflow::TensorShape& shape) {
394   static char empty;
395   tensorflow::int64 nelems = 1;
396   std::vector<tensorflow::int64> dims;
397   for (int i = 0; i < shape.dims(); ++i) {
398     dims.push_back(shape.dim_size(i));
399     nelems *= shape.dim_size(i);
400   }
401   CHECK_EQ(nelems, 0);
402   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
403                 "64-bit int types should match in size");
404   return TF_NewTensor(
405       dtype, reinterpret_cast<const int64_t*>(dims.data()), shape.dims(),
406       reinterpret_cast<void*>(&empty), 0, [](void*, size_t, void*) {}, nullptr);
407 }
408 
409 static void TF_Run_Helper(
410     Session* session, const char* handle, const TF_Buffer* run_options,
411     // Input tensors
412     const std::vector<std::pair<string, Tensor>>& input_pairs,
413     // Output tensors
414     const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
415     // Target nodes
416     const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
417     TF_Status* status) {
418   const int noutputs = output_tensor_names.size();
419   std::vector<Tensor> outputs(noutputs);
420   Status result;
421 
422   if (handle == nullptr) {
423     RunOptions run_options_proto;
424     if (run_options != nullptr && !run_options_proto.ParseFromArray(
425                                       run_options->data, run_options->length)) {
426       status->status = InvalidArgument("Unparseable RunOptions proto");
427       return;
428     }
429     if (run_metadata != nullptr && run_metadata->data != nullptr) {
430       status->status =
431           InvalidArgument("Passing non-empty run_metadata is invalid.");
432       return;
433     }
434 
435     RunMetadata run_metadata_proto;
436     result = session->Run(run_options_proto, input_pairs, output_tensor_names,
437                           target_oper_names, &outputs, &run_metadata_proto);
438 
439     // Serialize back to upstream client, who now owns the new buffer
440     if (run_metadata != nullptr) {
441       status->status = MessageToBuffer(run_metadata_proto, run_metadata);
442       if (!status->status.ok()) return;
443     }
444   } else {
445     // NOTE(zongheng): PRun does not support RunOptions yet.
446     result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
447   }
448   if (!result.ok()) {
449     status->status = result;
450     return;
451   }
452 
453   // Store results in c_outputs[]
454   for (int i = 0; i < noutputs; ++i) {
455     const Tensor& src = outputs[i];
456     if (!src.IsInitialized() || src.NumElements() == 0) {
457       c_outputs[i] =
458           EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
459       continue;
460     }
461     c_outputs[i] = TF_TensorFromTensor(src, &status->status);
462     if (!status->status.ok()) return;
463   }
464 }
465 
466 extern "C" {
467 
468 void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
469             // Input tensors
470             const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
471             // Output tensors
472             const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
473             // Target nodes
474             const char** c_target_oper_names, int ntargets,
475             TF_Buffer* run_metadata, TF_Status* status) {
476   TF_Run_Setup(noutputs, c_outputs, status);
477   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
478   if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
479   for (int i = 0; i < ninputs; ++i) {
480     input_pairs[i].first = c_input_names[i];
481   }
482   std::vector<string> output_names(noutputs);
483   for (int i = 0; i < noutputs; ++i) {
484     output_names[i] = c_output_names[i];
485   }
486   std::vector<string> target_oper_names(ntargets);
487   for (int i = 0; i < ntargets; ++i) {
488     target_oper_names[i] = c_target_oper_names[i];
489   }
490   TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
491                 c_outputs, target_oper_names, run_metadata, status);
492 }
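// Illustrative usage sketch (added commentary), assuming a valid
// TF_DeprecatedSession* `s`, a feed tensor `feed` for "x:0", and a graph that
// defines "y:0":
//
//   const char* feed_names[] = {"x:0"};
//   TF_Tensor* feed_vals[] = {feed};
//   const char* fetch_names[] = {"y:0"};
//   TF_Tensor* fetch_vals[1] = {nullptr};
//   TF_Run(s, nullptr, feed_names, feed_vals, 1, fetch_names, fetch_vals, 1,
//          nullptr, 0, nullptr, status);
//   // On success the caller owns fetch_vals[0] and releases it with
//   // TF_DeleteTensor.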
493 
494 void TF_PRunSetup(TF_DeprecatedSession* s,
495                   // Input names
496                   const char** c_input_names, int ninputs,
497                   // Output names
498                   const char** c_output_names, int noutputs,
499                   // Target nodes
500                   const char** c_target_oper_names, int ntargets,
501                   const char** handle, TF_Status* status) {
502   *handle = nullptr;
503 
504   std::vector<string> input_names(ninputs);
505   std::vector<string> output_names(noutputs);
506   std::vector<string> target_oper_names(ntargets);
507   for (int i = 0; i < ninputs; ++i) {
508     input_names[i] = c_input_names[i];
509   }
510   for (int i = 0; i < noutputs; ++i) {
511     output_names[i] = c_output_names[i];
512   }
513   for (int i = 0; i < ntargets; ++i) {
514     target_oper_names[i] = c_target_oper_names[i];
515   }
516   string new_handle;
517   status->status = s->session->PRunSetup(input_names, output_names,
518                                          target_oper_names, &new_handle);
519   if (status->status.ok()) {
520     char* buf = new char[new_handle.size() + 1];
521     memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
522     *handle = buf;
523   }
524 }
525 
526 void TF_PRun(TF_DeprecatedSession* s, const char* handle,
527              // Input tensors
528              const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
529              // Output tensors
530              const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
531              // Target nodes
532              const char** c_target_oper_names, int ntargets,
533              TF_Status* status) {
534   TF_Run_Setup(noutputs, c_outputs, status);
535   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
536   if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
537   for (int i = 0; i < ninputs; ++i) {
538     input_pairs[i].first = c_input_names[i];
539   }
540 
541   std::vector<string> output_names(noutputs);
542   for (int i = 0; i < noutputs; ++i) {
543     output_names[i] = c_output_names[i];
544   }
545   std::vector<string> target_oper_names(ntargets);
546   for (int i = 0; i < ntargets; ++i) {
547     target_oper_names[i] = c_target_oper_names[i];
548   }
549   TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
550                 c_outputs, target_oper_names, nullptr, status);
551 }
552 
553 TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
554   TF_Library* lib_handle = new TF_Library;
555   status->status = tensorflow::LoadDynamicLibrary(
556       library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
557       &lib_handle->op_list.length);
558   if (!status->status.ok()) {
559     delete lib_handle;
560     return nullptr;
561   }
562   return lib_handle;
563 }
564 
565 TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
566 
567 void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
568   if (lib_handle == nullptr) return;
569   tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
570   delete lib_handle;
571 }
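// Illustrative usage sketch (added commentary): loading an op/kernel plugin
// and inspecting the ops it registered; "my_ops.so" is a placeholder path.
//
//   TF_Library* lib = TF_LoadLibrary("my_ops.so", status);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_Buffer op_list = TF_GetOpList(lib);  // serialized OpList proto
//     // ... parse op_list.data (op_list.length bytes) ...
//     TF_DeleteLibraryHandle(lib);
//   }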
572 
573 TF_Buffer* TF_GetAllOpList() {
574   std::vector<tensorflow::OpDef> op_defs;
575   tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
576   tensorflow::OpList op_list;
577   for (const auto& op : op_defs) {
578     *(op_list.add_op()) = op;
579   }
580   TF_Buffer* ret = TF_NewBuffer();
581   TF_CHECK_OK(MessageToBuffer(op_list, ret));
582   return ret;
583 }
584 
585 // --------------------------------------------------------------------------
586 // ListDevices & SessionListDevices API
587 
588 void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; }
589 
590 TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
591   TF_DeviceList* response = new TF_DeviceList;
592   if (session && session->session)
593     status->status = session->session->ListDevices(&response->response);
594   return response;
595 }
596 
597 TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
598                                                TF_Status* status) {
599   TF_DeviceList* response = new TF_DeviceList;
600   if (session && session->session)
601     status->status = session->session->ListDevices(&response->response);
602   return response;
603 }
604 
605 int TF_DeviceListCount(const TF_DeviceList* list) {
606   return list->response.size();
607 }
608 
609 #define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
610   return_type method_name(const TF_DeviceList* list, const int index,     \
611                           TF_Status* status) {                            \
612     if (list == nullptr) {                                                \
613       status->status = InvalidArgument("list is null!");                  \
614       return err_val;                                                     \
615     }                                                                     \
616     if (index < 0 || index >= list->response.size()) {                    \
617       status->status = InvalidArgument("index out of bounds");            \
618       return err_val;                                                     \
619     }                                                                     \
620     status->status = Status::OK();                                        \
621     return list->response[index].accessor;                                \
622   }
623 
624 TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
625 TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
626                      nullptr);
627 TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
628 TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);
629 
630 #undef TF_DEVICELIST_METHOD
631 
632 }  // end extern "C"
633 
634 // --------------------------------------------------------------------------
635 // New Graph and Session API
636 
637 // Helper functions -----------------------------------------------------------
638 
639 namespace {
640 
641 TF_Operation* ToOperation(Node* node) {
642   return static_cast<TF_Operation*>(static_cast<void*>(node));
643 }
644 
645 string OutputName(const TF_Output& output) {
646   return StrCat(output.oper->node.name(), ":", output.index);
647 }
648 
649 const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
650                                           const char* attr_name,
651                                           TF_Status* status) {
652   const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
653   if (attr == nullptr) {
654     status->status = InvalidArgument("Operation '", oper->node.name(),
655                                      "' has no attr named '", attr_name, "'.");
656   }
657   return attr;
658 }
659 
660 TensorId ToTensorId(const TF_Output& output) {
661   return TensorId(output.oper->node.name(), output.index);
662 }
663 
664 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
665 std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
666                                                      int n) {
667   std::vector<tensorflow::Output> outputs(n);
668   for (int i = 0; i < n; ++i) {
669     outputs[i] =
670         tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
671   }
672   return outputs;
673 }
674 
675 void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
676                           TF_Output* tf_outputs) {
677   for (int i = 0; i < outputs.size(); i++) {
678     tf_outputs[i].oper = ToOperation(outputs[i].node());
679     tf_outputs[i].index = outputs[i].index();
680   }
681 }
682 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
683 
684 }  // namespace
685 
686 // Shape functions -----------------------------------------------------------
687 
688 void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
689                             const int64_t* dims, const int num_dims,
690                             TF_Status* status) {
691   Node* node = &output.oper->node;
692 
693   mutex_lock l(graph->mu);
694   tensorflow::shape_inference::InferenceContext* ic =
695       graph->refiner.GetContext(node);
696   if (ic == nullptr) {
697     status->status =
698         InvalidArgument("Node ", node->name(), " was not found in the graph");
699     return;
700   }
701   tensorflow::shape_inference::ShapeHandle new_shape =
702       tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
703   status->status = graph->refiner.SetShape(node, output.index, new_shape);
704 }
705 
706 int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
707                              TF_Status* status) {
708   Node* node = &output.oper->node;
709 
710   mutex_lock l(graph->mu);
711   tensorflow::shape_inference::InferenceContext* ic =
712       graph->refiner.GetContext(node);
713   if (ic == nullptr) {
714     status->status =
715         InvalidArgument("Node ", node->name(), " was not found in the graph");
716     return -1;
717   }
718 
719   tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
720 
721   // Unknown rank means the number of dimensions is -1.
722   if (!ic->RankKnown(shape)) {
723     return -1;
724   }
725 
726   return ic->Rank(shape);
727 }
728 
729 void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
730                             int num_dims, TF_Status* status) {
731   Node* node = &output.oper->node;
732 
733   mutex_lock l(graph->mu);
734   tensorflow::shape_inference::InferenceContext* ic =
735       graph->refiner.GetContext(node);
736   if (ic == nullptr) {
737     status->status =
738         InvalidArgument("Node ", node->name(), " was not found in the graph");
739     return;
740   }
741 
742   tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
743 
744   int rank = -1;
745   if (ic->RankKnown(shape)) {
746     rank = ic->Rank(shape);
747   }
748 
749   if (num_dims != rank) {
750     status->status = InvalidArgument("Expected rank is ", num_dims,
751                                      " but actual rank is ", rank);
752     return;
753   }
754 
755   if (num_dims == 0) {
756     // Output shape is a scalar.
757     return;
758   }
759 
760   // Rank is greater than 0, so fill in the values, if known, and
761   // -1 for unknown values.
762   for (int i = 0; i < num_dims; ++i) {
763     auto dim = ic->Dim(shape, i);
764     tensorflow::int64 value = -1;
765     if (ic->ValueKnown(dim)) {
766       value = ic->Value(dim);
767     }
768     dims[i] = value;
769   }
770 }
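// Illustrative usage sketch (added commentary): callers usually query the rank
// first and size the dims array from it; `graph` and `out` are assumed to be a
// valid TF_Graph* and TF_Output.
//
//   int rank = TF_GraphGetTensorNumDims(graph, out, status);
//   if (TF_GetCode(status) == TF_OK && rank > 0) {
//     std::vector<int64_t> dims(rank);
//     TF_GraphGetTensorShape(graph, out, dims.data(), rank, status);
//     // dims[i] == -1 marks a dimension whose size is unknown.
//   }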
771 
772 // TF_OperationDescription functions ------------------------------------------
773 
774 extern "C" {
775 
776 static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
777                                                       const char* op_type,
778                                                       const char* oper_name)
779     TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
780   return new TF_OperationDescription(graph, op_type, oper_name);
781 }
782 
783 TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
784                                          const char* oper_name) {
785   mutex_lock l(graph->mu);
786   return TF_NewOperationLocked(graph, op_type, oper_name);
787 }
788 
789 void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
790   desc->node_builder.Device(device);
791 }
792 
793 void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
794   desc->node_builder.Input(&input.oper->node, input.index);
795 }
796 
797 void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
798                      int num_inputs) {
799   std::vector<NodeBuilder::NodeOut> input_list;
800   input_list.reserve(num_inputs);
801   for (int i = 0; i < num_inputs; ++i) {
802     input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
803   }
804   desc->node_builder.Input(input_list);
805 }
806 
807 void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
808   desc->node_builder.ControlInput(&input->node);
809 }
810 
811 void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
812   desc->colocation_constraints.emplace(
813       StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
814 }
815 
816 void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
817                       const void* value, size_t length) {
818   tensorflow::StringPiece s(static_cast<const char*>(value), length);
819   desc->node_builder.Attr(attr_name, s);
820 }
821 
822 void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
823                           const void* const* values, const size_t* lengths,
824                           int num_values) {
825   if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
826     desc->colocation_constraints.clear();
827     for (int i = 0; i < num_values; ++i) {
828       desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
829                                            lengths[i]);
830     }
831   } else {
832     std::vector<tensorflow::StringPiece> v;
833     v.reserve(num_values);
834     for (int i = 0; i < num_values; ++i) {
835       v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
836     }
837     desc->node_builder.Attr(attr_name, v);
838   }
839 }
840 
841 void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
842                    int64_t value) {
843   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
844                 "64-bit int types should match in size");
845   desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
846 }
847 
848 void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
849                        const int64_t* values, int num_values) {
850   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
851                 "64-bit int types should match in size");
852   desc->node_builder.Attr(
853       attr_name,
854       ArraySlice<const tensorflow::int64>(
855           reinterpret_cast<const tensorflow::int64*>(values), num_values));
856 }
857 
858 void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
859                      float value) {
860   desc->node_builder.Attr(attr_name, value);
861 }
862 
863 void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
864                          const float* values, int num_values) {
865   desc->node_builder.Attr(attr_name,
866                           ArraySlice<const float>(values, num_values));
867 }
868 
869 void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
870                     unsigned char value) {
871   desc->node_builder.Attr(attr_name, static_cast<bool>(value));
872 }
873 
874 void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
875                         const unsigned char* values, int num_values) {
876   std::unique_ptr<bool[]> b(new bool[num_values]);
877   for (int i = 0; i < num_values; ++i) {
878     b[i] = values[i];
879   }
880   desc->node_builder.Attr(attr_name,
881                           ArraySlice<const bool>(b.get(), num_values));
882 }
883 
884 void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
885                     TF_DataType value) {
886   desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
887 }
888 
889 void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
890                         const TF_DataType* values, int num_values) {
891   desc->node_builder.Attr(
892       attr_name, ArraySlice<const DataType>(
893                      reinterpret_cast<const DataType*>(values), num_values));
894 }
895 
896 void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name,
897                            const char* placeholder) {
898   tensorflow::AttrValue attr_value;
899   attr_value.set_placeholder(placeholder);
900   desc->node_builder.Attr(attr_name, attr_value);
901 }
902 
903 void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
904                         const char* value, size_t length) {
905   tensorflow::NameAttrList func_name;
906   func_name.set_name(string(value, value + length));
907   desc->node_builder.Attr(attr_name, func_name);
908 }
909 
910 void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
911                      const int64_t* dims, int num_dims) {
912   PartialTensorShape shape;
913   if (num_dims >= 0) {
914     static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
915                   "64-bit int types should match in size");
916     shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
917         reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
918   }
919   desc->node_builder.Attr(attr_name, shape);
920 }
921 
922 void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
923                          const int64_t* const* dims, const int* num_dims,
924                          int num_shapes) {
925   std::vector<PartialTensorShape> shapes;
926   shapes.reserve(num_shapes);
927   for (int i = 0; i < num_shapes; ++i) {
928     if (num_dims[i] < 0) {
929       shapes.emplace_back();
930     } else {
931       static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
932                     "64-bit int types should match in size");
933       shapes.emplace_back(ArraySlice<tensorflow::int64>(
934           reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
935     }
936   }
937   desc->node_builder.Attr(attr_name, shapes);
938 }
939 
940 void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
941                                 const char* attr_name, const void* proto,
942                                 size_t proto_len, TF_Status* status) {
943   // shape.ParseFromArray takes an int as length, this function takes size_t,
944   // make sure there is no information loss.
945   if (proto_len > std::numeric_limits<int>::max()) {
946     status->status = InvalidArgument(
947         "proto_len (", proto_len,
948         " bytes) is too large to be parsed by the protocol buffer library");
949     return;
950   }
951   TensorShapeProto shape;
952   if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
953     desc->node_builder.Attr(attr_name, shape);
954     status->status = Status::OK();
955   } else {
956     status->status = InvalidArgument("Unparseable TensorShapeProto");
957   }
958 }
959 
960 void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
961                                     const char* attr_name,
962                                     const void* const* protos,
963                                     const size_t* proto_lens, int num_shapes,
964                                     TF_Status* status) {
965   std::vector<TensorShapeProto> shapes;
966   shapes.resize(num_shapes);
967   for (int i = 0; i < num_shapes; ++i) {
968     if (proto_lens[i] > std::numeric_limits<int>::max()) {
969       status->status = InvalidArgument(
970           "length of element ", i, " in the list (", proto_lens[i],
971           " bytes) is too large to be parsed by the protocol buffer library");
972       return;
973     }
974     if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) {
975       status->status =
976           InvalidArgument("Unparseable TensorShapeProto at index ", i);
977       return;
978     }
979   }
980   desc->node_builder.Attr(attr_name, shapes);
981   status->status = Status::OK();
982 }
983 
984 void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
985                       TF_Tensor* value, TF_Status* status) {
986   Tensor t;
987   status->status = TF_TensorToTensor(value, &t);
988   if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
989 }
990 
991 void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
992                           TF_Tensor* const* values, int num_values,
993                           TF_Status* status) {
994   status->status = Status::OK();
995   std::vector<Tensor> t;
996   t.reserve(num_values);
997 
998   for (int i = 0; i < num_values && status->status.ok(); ++i) {
999     Tensor v;
1000     status->status = TF_TensorToTensor(values[i], &v);
1001     t.emplace_back(v);
1002   }
1003 
1004   if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1005 }
1006 
1007 void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
1008                           const void* proto, size_t proto_len,
1009                           TF_Status* status) {
1010   tensorflow::AttrValue attr_value;
1011   if (!attr_value.ParseFromArray(proto, proto_len)) {
1012     status->status = InvalidArgument("Unparseable AttrValue proto");
1013     return;
1014   }
1015 
1016   if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1017     if (attr_value.value_case() != tensorflow::AttrValue::kList &&
1018         attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
1019       status->status =
1020           InvalidArgument("Expected \"list\" field for \"",
1021                           tensorflow::kColocationAttrName, "\" attribute");
1022       return;
1023     }
1024     desc->colocation_constraints.clear();
1025     for (const string& location : attr_value.list().s()) {
1026       desc->colocation_constraints.insert(location);
1027     }
1028   } else {
1029     desc->node_builder.Attr(attr_name, std::move(attr_value));
1030   }
1031 
1032   status->status = Status::OK();
1033 }
1034 
1035 static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
1036                                               TF_Status* status)
1037     TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
1038   Node* ret = nullptr;
1039 
1040   if (desc->graph->name_map.count(desc->node_builder.node_name())) {
1041     status->status = InvalidArgument("Duplicate node name in graph: '",
1042                                      desc->node_builder.node_name(), "'");
1043   } else {
1044     if (!desc->colocation_constraints.empty()) {
1045       desc->node_builder.Attr(
1046           tensorflow::kColocationAttrName,
1047           std::vector<string>(desc->colocation_constraints.begin(),
1048                               desc->colocation_constraints.end()));
1049     }
1050     status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret,
1051                                                  /*consume=*/true);
1052 
1053     if (status->status.ok()) {
1054       // Run shape inference function for newly added node.
1055       status->status = desc->graph->refiner.AddNode(ret);
1056     }
1057     if (status->status.ok()) {
1058       // Add the node to the name-to-node mapping.
1059       desc->graph->name_map[ret->name()] = ret;
1060     } else if (ret != nullptr) {
1061       desc->graph->graph.RemoveNode(ret);
1062       ret = nullptr;
1063     }
1064   }
1065 
1066   delete desc;
1067 
1068   return ToOperation(ret);
1069 }
1070 
1071 TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
1072                                  TF_Status* status) {
1073   mutex_lock l(desc->graph->mu);
1074   return TF_FinishOperationLocked(desc, status);
1075 }
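// Illustrative construction flow (added commentary), assuming a TF_Graph*
// `graph` and a standard build where the "Placeholder" op and its "dtype"
// attr are registered:
//
//   TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", "x");
//   TF_SetAttrType(desc, "dtype", TF_FLOAT);
//   TF_Operation* op = TF_FinishOperation(desc, status);  // consumes `desc`
//   // `op` is nullptr and `status` carries the error if finalization failed.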
1076 
1077 // TF_Operation functions
1078 // ----------------------------------------------------------
1079 
1080 const char* TF_OperationName(TF_Operation* oper) {
1081   return oper->node.name().c_str();
1082 }
1083 
1084 const char* TF_OperationOpType(TF_Operation* oper) {
1085   return oper->node.type_string().c_str();
1086 }
1087 
1088 const char* TF_OperationDevice(TF_Operation* oper) {
1089   return oper->node.requested_device().c_str();
1090 }
1091 
1092 int TF_OperationNumOutputs(TF_Operation* oper) {
1093   return oper->node.num_outputs();
1094 }
1095 
1096 TF_DataType TF_OperationOutputType(TF_Output oper_out) {
1097   return static_cast<TF_DataType>(
1098       oper_out.oper->node.output_type(oper_out.index));
1099 }
1100 
1101 int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
1102                                  TF_Status* status) {
1103   NameRangeMap name_ranges;
1104   status->status =
1105       NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
1106   if (!status->status.ok()) return -1;
1107   auto iter = name_ranges.find(arg_name);
1108   if (iter == name_ranges.end()) {
1109     status->status = InvalidArgument("Output arg '", arg_name, "' not found");
1110     return -1;
1111   }
1112   return iter->second.second - iter->second.first;
1113 }
1114 
1115 int TF_OperationNumInputs(TF_Operation* oper) {
1116   return oper->node.num_inputs();
1117 }
1118 
1119 TF_DataType TF_OperationInputType(TF_Input oper_in) {
1120   return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
1121 }
1122 
1123 int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
1124                                 TF_Status* status) {
1125   NameRangeMap name_ranges;
1126   status->status =
1127       NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
1128   if (!status->status.ok()) return -1;
1129   auto iter = name_ranges.find(arg_name);
1130   if (iter == name_ranges.end()) {
1131     status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1132     return -1;
1133   }
1134   return iter->second.second - iter->second.first;
1135 }
1136 
1137 TF_Output TF_OperationInput(TF_Input oper_in) {
1138   const tensorflow::Edge* edge;
1139   Status s = oper_in.oper->node.input_edge(oper_in.index, &edge);
1140   if (!s.ok()) {
1141     return {nullptr, -1};
1142   }
1143 
1144   return {ToOperation(edge->src()), edge->src_output()};
1145 }
1146 
1147 void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs,
1148                            int max_inputs) {
1149   for (auto* edge : oper->node.in_edges()) {
1150     if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) {
1151       inputs[edge->dst_input()] = {ToOperation(edge->src()),
1152                                    edge->src_output()};
1153     }
1154   }
1155 }
1156 
1157 int TF_OperationOutputNumConsumers(TF_Output oper_out) {
1158   int count = 0;
1159   for (const auto* edge : oper_out.oper->node.out_edges()) {
1160     if (edge->src_output() == oper_out.index) {
1161       ++count;
1162     }
1163   }
1164   return count;
1165 }
1166 
1167 int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
1168                                 int max_consumers) {
1169   int count = 0;
1170   for (const auto* edge : oper_out.oper->node.out_edges()) {
1171     if (edge->src_output() == oper_out.index) {
1172       if (count < max_consumers) {
1173         consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
1174       }
1175       ++count;
1176     }
1177   }
1178   return count;
1179 }
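// Illustrative two-step pattern (added commentary) for listing the consumers
// of an output, assuming a valid TF_Output `out`:
//
//   int n = TF_OperationOutputNumConsumers(out);
//   std::vector<TF_Input> consumers(n);
//   TF_OperationOutputConsumers(out, consumers.data(), n);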
1180 
1181 int TF_OperationNumControlInputs(TF_Operation* oper) {
1182   int count = 0;
1183   for (const auto* edge : oper->node.in_edges()) {
1184     if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1185       ++count;
1186     }
1187   }
1188   return count;
1189 }
1190 
1191 int TF_OperationGetControlInputs(TF_Operation* oper,
1192                                  TF_Operation** control_inputs,
1193                                  int max_control_inputs) {
1194   int count = 0;
1195   for (const auto* edge : oper->node.in_edges()) {
1196     if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1197       if (count < max_control_inputs) {
1198         control_inputs[count] = ToOperation(edge->src());
1199       }
1200       ++count;
1201     }
1202   }
1203   return count;
1204 }
1205 
1206 int TF_OperationNumControlOutputs(TF_Operation* oper) {
1207   int count = 0;
1208   for (const auto* edge : oper->node.out_edges()) {
1209     if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1210       ++count;
1211     }
1212   }
1213   return count;
1214 }
1215 
1216 int TF_OperationGetControlOutputs(TF_Operation* oper,
1217                                   TF_Operation** control_outputs,
1218                                   int max_control_outputs) {
1219   int count = 0;
1220   for (const auto* edge : oper->node.out_edges()) {
1221     if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1222       if (count < max_control_outputs) {
1223         control_outputs[count] = ToOperation(edge->dst());
1224       }
1225       ++count;
1226     }
1227   }
1228   return count;
1229 }
1230 
1231 TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1232                                             const char* attr_name,
1233                                             TF_Status* status) {
1234   TF_AttrMetadata metadata;
1235   const auto* attr = GetAttrValue(oper, attr_name, status);
1236   if (!status->status.ok()) return metadata;
1237   switch (attr->value_case()) {
1238 #define SINGLE_CASE(kK, attr_type, size_expr) \
1239   case tensorflow::AttrValue::kK:             \
1240     metadata.is_list = 0;                     \
1241     metadata.list_size = -1;                  \
1242     metadata.type = attr_type;                \
1243     metadata.total_size = size_expr;          \
1244     break;
1245 
1246     SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
1247     SINGLE_CASE(kI, TF_ATTR_INT, -1);
1248     SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
1249     SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
1250     SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
1251     SINGLE_CASE(kShape, TF_ATTR_SHAPE,
1252                 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
1253     SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
1254 #undef SINGLE_CASE
1255 
1256     case tensorflow::AttrValue::kList:
1257       metadata.is_list = 1;
1258       metadata.list_size = 0;
1259       metadata.total_size = -1;
1260 #define LIST_CASE(field, attr_type, ...)              \
1261   if (attr->list().field##_size() > 0) {              \
1262     metadata.type = attr_type;                        \
1263     metadata.list_size = attr->list().field##_size(); \
1264     __VA_ARGS__;                                      \
1265     break;                                            \
1266   }
1267 
1268       LIST_CASE(
1269           s, TF_ATTR_STRING, metadata.total_size = 0;
1270           for (int i = 0; i < attr->list().s_size();
1271                ++i) { metadata.total_size += attr->list().s(i).size(); });
1272       LIST_CASE(i, TF_ATTR_INT);
1273       LIST_CASE(f, TF_ATTR_FLOAT);
1274       LIST_CASE(b, TF_ATTR_BOOL);
1275       LIST_CASE(type, TF_ATTR_TYPE);
1276       LIST_CASE(
1277           shape, TF_ATTR_SHAPE, metadata.total_size = 0;
1278           for (int i = 0; i < attr->list().shape_size(); ++i) {
1279             const auto& s = attr->list().shape(i);
1280             metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
1281           });
1282       LIST_CASE(tensor, TF_ATTR_TENSOR);
1283       LIST_CASE(func, TF_ATTR_FUNC);
1284 #undef LIST_CASE
1285       // All lists empty, determine the type from the OpDef.
1286       if (metadata.list_size == 0) {
1287         for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
1288           const auto& a = oper->node.op_def().attr(i);
1289           if (a.name() != attr_name) continue;
1290           const string& typestr = a.type();
1291           if (typestr == "list(string)") {
1292             metadata.type = TF_ATTR_STRING;
1293           } else if (typestr == "list(int)") {
1294             metadata.type = TF_ATTR_INT;
1295           } else if (typestr == "list(float)") {
1296             metadata.type = TF_ATTR_FLOAT;
1297           } else if (typestr == "list(bool)") {
1298             metadata.type = TF_ATTR_BOOL;
1299           } else if (typestr == "list(type)") {
1300             metadata.type = TF_ATTR_TYPE;
1301           } else if (typestr == "list(shape)") {
1302             metadata.type = TF_ATTR_SHAPE;
1303           } else if (typestr == "list(tensor)") {
1304             metadata.type = TF_ATTR_TENSOR;
1305           } else if (typestr == "list(func)") {
1306             metadata.type = TF_ATTR_FUNC;
1307           } else {
1308             status->status = InvalidArgument(
1309                 "Attribute '", attr_name,
1310                 "' has an empty value of an unrecognized type '", typestr, "'");
1311             return metadata;
1312           }
1313         }
1314       }
1315       break;
1316 
1317     case tensorflow::AttrValue::kPlaceholder:
1318       metadata.is_list = 0;
1319       metadata.list_size = -1;
1320       metadata.type = TF_ATTR_PLACEHOLDER;
1321       metadata.total_size = -1;
1322       break;
1323 
1324     case tensorflow::AttrValue::kFunc:
1325       metadata.is_list = 0;
1326       metadata.list_size = -1;
1327       metadata.type = TF_ATTR_FUNC;
1328       metadata.total_size = -1;
1329       break;
1330 
1331     case tensorflow::AttrValue::VALUE_NOT_SET:
1332       status->status =
1333           InvalidArgument("Attribute '", attr_name, "' has no value set");
1334       break;
1335   }
1336   return metadata;
1337 }
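
// Usage sketch: TF_AttrMetadata is typically used to size buffers before
// calling the typed getters below. For a single string attribute (the name
// "value" is illustrative; `op` and the TF_Status* `s` are assumed valid):
//
//   TF_AttrMetadata m = TF_OperationGetAttrMetadata(op, "value", s);
//   if (TF_GetCode(s) == TF_OK && !m.is_list && m.type == TF_ATTR_STRING) {
//     std::string buf(m.total_size, '\0');
//     TF_OperationGetAttrString(op, "value", &buf[0], buf.size(), s);
//   }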
1338 
1339 void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
1340                                void* value, size_t max_length,
1341                                TF_Status* status) {
1342   const auto* attr = GetAttrValue(oper, attr_name, status);
1343   if (!status->status.ok()) return;
1344   if (attr->value_case() != tensorflow::AttrValue::kS) {
1345     status->status =
1346         InvalidArgument("Attribute '", attr_name, "' is not a string");
1347     return;
1348   }
1349   if (max_length <= 0) {
1350     return;
1351   }
1352   const auto& s = attr->s();
1353   std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length));
1354 }
1355 
1356 void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
1357                                    void** values, size_t* lengths,
1358                                    int max_values, void* storage,
1359                                    size_t storage_size, TF_Status* status) {
1360   const auto* attr = GetAttrValue(oper, attr_name, status);
1361   if (!status->status.ok()) return;
1362   if (attr->value_case() != tensorflow::AttrValue::kList) {
1363     status->status =
1364         InvalidArgument("Value for '", attr_name, "' is not a list");
1365     return;
1366   }
1367   const auto len = std::min(max_values, attr->list().s_size());
1368   char* p = static_cast<char*>(storage);
1369   for (int i = 0; i < len; ++i) {
1370     const string& s = attr->list().s(i);
1371     values[i] = p;
1372     lengths[i] = s.size();
1373     if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
1374       status->status = InvalidArgument(
1375           "Not enough storage to hold the requested list of strings");
1376       return;
1377     }
1378     memcpy(values[i], s.data(), s.size());
1379     p += s.size();
1380   }
1381 }
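
// Usage sketch: for a list(string) attribute, metadata.list_size gives the
// number of strings and metadata.total_size the number of bytes needed for
// `storage`. The attribute name "names" is illustrative; `op` and `s` are
// assumed valid.
//
//   TF_AttrMetadata m = TF_OperationGetAttrMetadata(op, "names", s);
//   std::vector<void*> values(m.list_size);
//   std::vector<size_t> lengths(m.list_size);
//   std::vector<char> storage(m.total_size);
//   TF_OperationGetAttrStringList(op, "names", values.data(), lengths.data(),
//                                 static_cast<int>(m.list_size),
//                                 storage.data(), storage.size(), s);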
1382 
1383 #define DEFINE_GETATTR(func, c_type, cpp_type, list_field)                   \
1384   void func(TF_Operation* oper, const char* attr_name, c_type* value,        \
1385             TF_Status* status) {                                             \
1386     cpp_type v;                                                              \
1387     status->status =                                                         \
1388         tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v);          \
1389     if (!status->status.ok()) return;                                        \
1390     *value = static_cast<c_type>(v);                                         \
1391   }                                                                          \
1392   void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
1393                   int max_values, TF_Status* status) {                       \
1394     const auto* attr = GetAttrValue(oper, attr_name, status);                \
1395     if (!status->status.ok()) return;                                        \
1396     if (attr->value_case() != tensorflow::AttrValue::kList) {                \
1397       status->status =                                                       \
1398           InvalidArgument("Value for '", attr_name, "' is not a list.");     \
1399       return;                                                                \
1400     }                                                                        \
1401     const auto len = std::min(max_values, attr->list().list_field##_size()); \
1402     for (int i = 0; i < len; ++i) {                                          \
1403       values[i] = static_cast<c_type>(attr->list().list_field(i));           \
1404     }                                                                        \
1405   }
1406 DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
1407 DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
1408 DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
1409 DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
1410 #undef DEFINE_GETATTR
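
// Usage sketch for the macro-generated getters above: scalars come back
// through an out-parameter, lists through a caller-sized array. "N" and
// "strides" are illustrative attribute names; `op` and `s` are assumed valid.
//
//   int64_t n = 0;
//   TF_OperationGetAttrInt(op, "N", &n, s);
//
//   TF_AttrMetadata m = TF_OperationGetAttrMetadata(op, "strides", s);
//   std::vector<int64_t> strides(m.list_size);
//   TF_OperationGetAttrIntList(op, "strides", strides.data(),
//                              static_cast<int>(strides.size()), s);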
1411 
1412 void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
1413                               int64_t* value, int num_dims, TF_Status* status) {
1414   PartialTensorShape shape;
1415   status->status =
1416       tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
1417   if (!status->status.ok()) return;
1418   auto len = std::min(shape.dims(), num_dims);
1419   for (int i = 0; i < len; ++i) {
1420     value[i] = shape.dim_size(i);
1421   }
1422 }
1423 
1424 void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
1425                                   int64_t** dims, int* num_dims, int num_shapes,
1426                                   int64_t* storage, int storage_size,
1427                                   TF_Status* status) {
1428   std::vector<PartialTensorShape> shapes;
1429   status->status =
1430       tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
1431   if (!status->status.ok()) return;
1432   auto len = std::min(static_cast<int>(shapes.size()), num_shapes);
1433   int64_t* p = storage;
1434   int storage_left = storage_size;
1435   for (int i = 0; i < len; ++i) {
1436     // shapes[i].dims() == -1 for shapes with an unknown rank.
1437     int64_t n = shapes[i].dims();
1438     num_dims[i] = n;
1439     dims[i] = p;
1440     if (n < 0) {
1441       continue;
1442     }
1443     if (storage_left < n) {
1444       status->status = InvalidArgument(
1445           "Not enough storage to hold the requested list of shapes");
1446       return;
1447     }
1448     storage_left -= n;
1449     for (int j = 0; j < n; ++j, ++p) {
1450       *p = shapes[i].dim_size(j);
1451     }
1452   }
1453 }
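
// Usage sketch: each dims[i] points into caller-provided `storage`, and
// num_dims[i] is -1 for a shape of unknown rank (no dimensions are written in
// that case). The fixed sizes below are assumptions for illustration; `op`
// and `s` are assumed valid.
//
//   int64_t* dims[8];
//   int num_dims[8];
//   int64_t storage[64];
//   TF_OperationGetAttrShapeList(op, "shapes", dims, num_dims, 8, storage, 64,
//                                s);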
1454 
1455 void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
1456                                          const char* attr_name,
1457                                          TF_Buffer* value, TF_Status* status) {
1458   const auto* attr = GetAttrValue(oper, attr_name, status);
1459   if (!status->status.ok()) return;
1460   if (attr->value_case() != tensorflow::AttrValue::kShape) {
1461     status->status =
1462         InvalidArgument("Value for '", attr_name, "' is not a shape.");
1463     return;
1464   }
1465   status->status = MessageToBuffer(attr->shape(), value);
1466 }
1467 
1468 void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
1469                                              const char* attr_name,
1470                                              TF_Buffer** values, int max_values,
1471                                              TF_Status* status) {
1472   const auto* attr = GetAttrValue(oper, attr_name, status);
1473   if (!status->status.ok()) return;
1474   if (attr->value_case() != tensorflow::AttrValue::kList) {
1475     status->status =
1476         InvalidArgument("Value for '", attr_name, "' is not a list");
1477     return;
1478   }
1479   const auto len = std::min(max_values, attr->list().shape_size());
1480   for (int i = 0; i < len; ++i) {
1481     values[i] = TF_NewBuffer();
1482     status->status = MessageToBuffer(attr->list().shape(i), values[i]);
1483     if (!status->status.ok()) {
1484       // Delete everything allocated so far; the operation has failed.
1485       for (int j = 0; j <= i; ++j) {
1486         TF_DeleteBuffer(values[j]);
1487       }
1488       return;
1489     }
1490   }
1491 }
1492 
1493 void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
1494                                TF_Tensor** value, TF_Status* status) {
1495   *value = nullptr;
1496   Tensor t;
1497   status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
1498   if (!status->status.ok()) return;
1499   *value = TF_TensorFromTensor(t, &status->status);
1500 }
1501 
1502 void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
1503                                    TF_Tensor** values, int max_values,
1504                                    TF_Status* status) {
1505   std::vector<Tensor> ts;
1506   status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
1507   if (!status->status.ok()) return;
1508   const auto len = std::min(max_values, static_cast<int>(ts.size()));
1509   for (int i = 0; i < len; ++i) {
1510     values[i] = TF_TensorFromTensor(ts[i], &status->status);
1511   }
1512 }
1513 
1514 void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
1515                                    TF_Buffer* output_attr_value,
1516                                    TF_Status* status) {
1517   const auto* attr = GetAttrValue(oper, attr_name, status);
1518   if (!status->status.ok()) return;
1519   status->status = MessageToBuffer(*attr, output_attr_value);
1520 }
1521 
1522 void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
1523                            TF_Status* status) {
1524   status->status = MessageToBuffer(oper->node.def(), output_node_def);
1525 }
1526 
1527 // TF_Graph functions ---------------------------------------------------------
1528 
1529 TF_Graph::TF_Graph()
1530     : graph(tensorflow::OpRegistry::Global()),
1531       refiner(graph.versions().producer(), graph.op_registry()),
1532       delete_requested(false),
1533       parent(nullptr),
1534       parent_inputs(nullptr) {
1535   // Tell the shape refiner to also run shape inference on functions.
1536   refiner.set_function_library_for_shape_inference(&graph.flib_def());
1537 }
1538 
1539 TF_Graph* TF_NewGraph() { return new TF_Graph; }
1540 
1541 void TF_DeleteGraph(TF_Graph* g) {
1542   if (g == nullptr) return;
1543   g->mu.lock();
1544   g->delete_requested = true;
1545   const bool del = g->sessions.empty();
1546   g->mu.unlock();
1547   if (del) delete g;
1548 }
1549 
1550 TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
1551   mutex_lock l(graph->mu);
1552   auto iter = graph->name_map.find(oper_name);
1553   if (iter == graph->name_map.end()) {
1554     return nullptr;
1555   } else {
1556     return ToOperation(iter->second);
1557   }
1558 }
1559 
1560 TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
1561   if (*pos == 0) {
1562     // Advance past the first sentinel nodes in every graph (the source & sink).
1563     *pos += 2;
1564   } else {
1565     // Advance to the next node.
1566     *pos += 1;
1567   }
1568 
1569   mutex_lock l(graph->mu);
1570   while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) {
1571     Node* node = graph->graph.FindNodeId(*pos);
1572     // FindNodeId() returns nullptr for nodes that have been deleted.
1573     // We aren't currently allowing nodes to be deleted, but it is safer
1574     // to still check.
1575     if (node != nullptr) return ToOperation(node);
1576     *pos += 1;
1577   }
1578 
1579   // No more nodes.
1580   return nullptr;
1581 }
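
// Usage sketch: enumerating every operation in a graph. `pos` must start at 0
// and is advanced by each call; `graph` is assumed to be a valid TF_Graph*.
//
//   size_t pos = 0;
//   TF_Operation* op;
//   while ((op = TF_GraphNextOperation(graph, &pos)) != nullptr) {
//     printf("%s (%s)\n", TF_OperationName(op), TF_OperationOpType(op));
//   }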
1582 
1583 void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
1584                         TF_Status* status) {
1585   GraphDef def;
1586   {
1587     mutex_lock l(graph->mu);
1588     graph->graph.ToGraphDef(&def);
1589   }
1590   status->status = MessageToBuffer(def, output_graph_def);
1591 }
1592 
1593 void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
1594                       TF_Buffer* output_op_def, TF_Status* status) {
1595   const OpDef* op_def;
1596   {
1597     mutex_lock l(graph->mu);
1598     status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
1599     if (!status->status.ok()) return;
1600   }
1601   status->status = MessageToBuffer(*op_def, output_op_def);
1602 }
1603 
1604 void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
1605                       TF_Status* status) {
1606   VersionDef versions;
1607   {
1608     mutex_lock l(graph->mu);
1609     versions = graph->graph.versions();
1610   }
1611   status->status = MessageToBuffer(versions, output_version_def);
1612 }
1613 
1614 TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
1615   return new TF_ImportGraphDefOptions;
1616 }
1617 void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
1618   delete opts;
1619 }
1620 void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
1621                                        const char* prefix) {
1622   opts->opts.prefix = prefix;
1623 }
1624 void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts,
1625                                               const char* device) {
1626   opts->opts.default_device = device;
1627 }
1628 
1629 void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
1630                                               unsigned char uniquify_names) {
1631   opts->opts.uniquify_names = uniquify_names;
1632 }
1633 
1634 void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
1635                                                unsigned char uniquify_prefix) {
1636   opts->opts.uniquify_prefix = uniquify_prefix;
1637 }
1638 
1639 void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
1640                                              const char* src_name,
1641                                              int src_index, TF_Output dst) {
1642   opts->tensor_id_data.push_back(src_name);
1643   const string& src_name_str = opts->tensor_id_data.back();
1644   // We don't need to store dst's name in tensor_id_data, since `dst` must
1645   // outlive the ImportGraphDef call.
1646   opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
1647 }
1648 
1649 void TF_ImportGraphDefOptionsRemapControlDependency(
1650     TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
1651   opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
1652       TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
1653 }
1654 
1655 extern void TF_ImportGraphDefOptionsAddControlDependency(
1656     TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
1657   opts->opts.control_dependencies.push_back(oper->node.name());
1658 }
1659 
1660 void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
1661                                              const char* oper_name, int index) {
1662   opts->tensor_id_data.push_back(oper_name);
1663   const string& oper_name_str = opts->tensor_id_data.back();
1664   opts->opts.return_tensors.emplace_back(oper_name_str, index);
1665 }
1666 
1667 int TF_ImportGraphDefOptionsNumReturnOutputs(
1668     const TF_ImportGraphDefOptions* opts) {
1669   return opts->opts.return_tensors.size();
1670 }
1671 
1672 void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
1673                                                 const char* oper_name) {
1674   opts->opts.return_nodes.push_back(oper_name);
1675 }
1676 
1677 int TF_ImportGraphDefOptionsNumReturnOperations(
1678     const TF_ImportGraphDefOptions* opts) {
1679   return opts->opts.return_nodes.size();
1680 }
1681 
1682 void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
1683                                            int* num_outputs,
1684                                            TF_Output** outputs) {
1685   *num_outputs = results->return_tensors.size();
1686   *outputs = results->return_tensors.data();
1687 }
1688 
1689 void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
1690                                               int* num_opers,
1691                                               TF_Operation*** opers) {
1692   *num_opers = results->return_nodes.size();
1693   *opers = results->return_nodes.data();
1694 }
1695 
1696 void TF_ImportGraphDefResultsMissingUnusedInputMappings(
1697     TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
1698     const char*** src_names, int** src_indexes) {
1699   *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
1700   *src_names = results->missing_unused_key_names.data();
1701   *src_indexes = results->missing_unused_key_indexes.data();
1702 }
1703 
1704 void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
1705   delete results;
1706 }
1707 
1708 static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
1709                                       const TF_ImportGraphDefOptions* opts,
1710                                       TF_ImportGraphDefResults* tf_results,
1711                                       TF_Status* status)
1712     TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
1713   const int last_node_id = graph->graph.num_node_ids();
1714   tensorflow::ImportGraphDefResults results;
1715   status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
1716                                               &graph->refiner, &results);
1717   if (!status->status.ok()) return;
1718 
1719   // Add new nodes to name_map
1720   for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
1721     auto* node = graph->graph.FindNodeId(i);
1722     if (node != nullptr) graph->name_map[node->name()] = node;
1723   }
1724 
1725   // Populate return_tensors
1726   DCHECK(tf_results->return_tensors.empty());
1727   tf_results->return_tensors.resize(results.return_tensors.size());
1728   for (int i = 0; i < results.return_tensors.size(); ++i) {
1729     tf_results->return_tensors[i].oper =
1730         ToOperation(results.return_tensors[i].first);
1731     tf_results->return_tensors[i].index = results.return_tensors[i].second;
1732   }
1733 
1734   // Populate return_nodes
1735   DCHECK(tf_results->return_nodes.empty());
1736   tf_results->return_nodes.resize(results.return_nodes.size());
1737   for (int i = 0; i < results.return_nodes.size(); ++i) {
1738     tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
1739   }
1740 
1741   // Populate missing unused map keys
1742   DCHECK(tf_results->missing_unused_key_names.empty());
1743   DCHECK(tf_results->missing_unused_key_indexes.empty());
1744   DCHECK(tf_results->missing_unused_key_names_data.empty());
1745 
1746   size_t size = results.missing_unused_input_map_keys.size();
1747   tf_results->missing_unused_key_names.resize(size);
1748   tf_results->missing_unused_key_indexes.resize(size);
1749 
1750   for (int i = 0; i < size; ++i) {
1751     TensorId id = results.missing_unused_input_map_keys[i];
1752     tf_results->missing_unused_key_names_data.emplace_back(id.first);
1753     tf_results->missing_unused_key_names[i] =
1754         tf_results->missing_unused_key_names_data.back().c_str();
1755     tf_results->missing_unused_key_indexes[i] = id.second;
1756   }
1757 }
1758 
1759 TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
1760     TF_Graph* graph, const TF_Buffer* graph_def,
1761     const TF_ImportGraphDefOptions* options, TF_Status* status) {
1762   GraphDef def;
1763   if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
1764                                        graph_def->length)) {
1765     status->status = InvalidArgument("Invalid GraphDef");
1766     return nullptr;
1767   }
1768   auto results = new TF_ImportGraphDefResults();
1769   mutex_lock l(graph->mu);
1770   GraphImportGraphDefLocked(graph, def, options, results, status);
1771   if (!status->status.ok()) {
1772     delete results;
1773     return nullptr;
1774   }
1775   return results;
1776 }
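
// Usage sketch: importing a serialized GraphDef and retrieving a requested
// return output. `graph_def` is assumed to be a TF_Buffer* holding a valid
// GraphDef that contains an output "x:0"; `graph` and `s` are assumed valid.
//
//   TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
//   TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
//   TF_ImportGraphDefOptionsAddReturnOutput(opts, "x", 0);
//   TF_ImportGraphDefResults* res =
//       TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
//   if (TF_GetCode(s) == TF_OK) {
//     int num_outputs;
//     TF_Output* outputs;
//     TF_ImportGraphDefResultsReturnOutputs(res, &num_outputs, &outputs);
//     // ... use outputs[0] ...
//     TF_DeleteImportGraphDefResults(res);
//   }
//   TF_DeleteImportGraphDefOptions(opts);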
1777 
1778 void TF_GraphImportGraphDefWithReturnOutputs(
1779     TF_Graph* graph, const TF_Buffer* graph_def,
1780     const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
1781     int num_return_outputs, TF_Status* status) {
1782   if (num_return_outputs != options->opts.return_tensors.size()) {
1783     status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
1784                                      options->opts.return_tensors.size(),
1785                                      ", got ", num_return_outputs);
1786     return;
1787   }
1788   if (num_return_outputs > 0 && return_outputs == nullptr) {
1789     status->status = InvalidArgument(
1790         "'return_outputs' must be preallocated to length ", num_return_outputs);
1791     return;
1792   }
1793   GraphDef def;
1794   if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
1795                                        graph_def->length)) {
1796     status->status = InvalidArgument("Invalid GraphDef");
1797     return;
1798   }
1799   TF_ImportGraphDefResults results;
1800   mutex_lock l(graph->mu);
1801   GraphImportGraphDefLocked(graph, def, options, &results, status);
1802   DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
1803   memcpy(return_outputs, results.return_tensors.data(),
1804          num_return_outputs * sizeof(TF_Output));
1805 }
1806 
1807 void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
1808                             const TF_ImportGraphDefOptions* options,
1809                             TF_Status* status) {
1810   TF_ImportGraphDefResults* results =
1811       TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
1812   TF_DeleteImportGraphDefResults(results);
1813 }
1814 
1815 // While loop functions -------------------------------------------------------
1816 
1817 namespace {
1818 
1819 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1820 
1821 // Creates a placeholder representing an input to the cond or body graph.
1822 // TODO(skyewm): remove these from final graph
1823 bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
1824                  TF_Output* input, TF_Status* status) {
1825   TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
1826   TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
1827   // TODO(skyewm): set placeholder shape
1828   TF_Operation* oper = TF_FinishOperation(desc, status);
1829   if (!status->status.ok()) return false;
1830   *input = {oper, 0};
1831   return true;
1832 }
1833 
1834 // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
1835 // `src_inputs[i]` will have that input replaced with `dst_inputs[i]`.  `prefix`
1836 // will be prepended to copied node names. `control_deps` are nodes in
1837 // `dst_graph` that the copied `src_graph` nodes will have control dependencies
1838 // on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
1839 // in `dst_graph` will be returned. `return_nodes` must be non-null.
1840 Status CopyGraph(Graph* src_graph, Graph* dst_graph,
1841                  tensorflow::ShapeRefiner* dst_refiner,
1842                  const TF_Output* src_inputs,
1843                  const std::vector<tensorflow::Output>& dst_inputs,
1844                  const string& prefix,
1845                  const std::vector<tensorflow::Operation>& control_deps,
1846                  const TF_Output* nodes_to_return, int nreturn_nodes,
1847                  std::vector<tensorflow::Output>* return_nodes) {
1848   DCHECK(return_nodes != nullptr);
1849   GraphDef gdef;
1850   src_graph->ToGraphDef(&gdef);
1851 
1852   tensorflow::ImportGraphDefOptions opts;
1853   opts.prefix = prefix;
1854 
1855   for (int i = 0; i < dst_inputs.size(); ++i) {
1856     opts.input_map[ToTensorId(src_inputs[i])] =
1857         TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
1858   }
1859   opts.skip_mapped_nodes = true;
1860 
1861   for (const tensorflow::Operation& op : control_deps) {
1862     opts.control_dependencies.push_back(op.node()->name());
1863   }
1864 
1865   for (int i = 0; i < nreturn_nodes; ++i) {
1866     opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
1867   }
1868 
1869   // TODO(skyewm): change to OutputTensor
1870   tensorflow::ImportGraphDefResults results;
1871   TF_RETURN_IF_ERROR(
1872       ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
1873 
1874   for (const auto& pair : results.return_tensors) {
1875     return_nodes->emplace_back(pair.first, pair.second);
1876   }
1877   return Status::OK();
1878 }
1879 
1880 bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
1881   if (params.cond_graph == nullptr || params.body_graph == nullptr ||
1882       params.cond_graph->parent == nullptr ||
1883       params.cond_graph->parent != params.body_graph->parent ||
1884       params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
1885       params.ninputs <= 0 || params.cond_inputs == nullptr ||
1886       params.body_inputs == nullptr || params.body_outputs == nullptr) {
1887     s->status = InvalidArgument(
1888         "TF_WhileParams must be created by successful TF_NewWhile() call");
1889     return false;
1890   }
1891   return true;
1892 }
1893 
1894 bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
1895   if (params.cond_output.oper == nullptr) {
1896     s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
1897     return false;
1898   }
1899   for (int i = 0; i < params.ninputs; ++i) {
1900     if (params.body_outputs[i].oper == nullptr) {
1901       s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
1902                                   "field isn't set");
1903       return false;
1904     }
1905   }
1906   if (params.name == nullptr) {
1907     s->status = InvalidArgument("TF_WhileParams `name` field is null");
1908     return false;
1909   }
1910   return true;
1911 }
1912 
1913 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1914 
1915 void FreeWhileResources(const TF_WhileParams* params) {
1916   TF_DeleteGraph(params->cond_graph);
1917   TF_DeleteGraph(params->body_graph);
1918   delete[] params->cond_inputs;
1919   delete[] params->body_inputs;
1920   delete[] params->body_outputs;
1921 }
1922 
1923 TF_WhileParams EmptyWhileParams() {
1924   return {0,       nullptr, nullptr, {nullptr, 0},
1925           nullptr, nullptr, nullptr, nullptr};
1926 }
1927 
1928 }  // namespace
1929 
1930 TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
1931                            TF_Status* status) {
1932 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
1933   status->status = tensorflow::errors::Unimplemented(
1934       "Creating while loops is not supported on mobile. File a bug at "
1935       "https://github.com/tensorflow/tensorflow/issues if this feature is "
1936       "important to you");
1937   return EmptyWhileParams();
1938 #else
1939   if (ninputs == 0) {
1940     status->status =
1941         InvalidArgument("TF_NewWhile() must be passed at least one input");
1942     return EmptyWhileParams();
1943   }
1944 
1945   TF_Graph* cond_graph = TF_NewGraph();
1946   TF_Graph* body_graph = TF_NewGraph();
1947   cond_graph->parent = g;
1948   cond_graph->parent_inputs = inputs;
1949   body_graph->parent = g;
1950   body_graph->parent_inputs = inputs;
1951 
1952   TF_Output* cond_inputs = new TF_Output[ninputs];
1953   TF_Output cond_output = {nullptr, -1};
1954   TF_Output* body_inputs = new TF_Output[ninputs];
1955   TF_Output* body_outputs = new TF_Output[ninputs];
1956   for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
1957   const char* name = nullptr;
1958 
1959   for (int i = 0; i < ninputs; ++i) {
1960     // TODO(skyewm): prefix names with underscore (requires some plumbing)
1961     if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
1962                      &cond_inputs[i], status)) {
1963       break;
1964     }
1965     if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
1966                      &body_inputs[i], status)) {
1967       break;
1968     }
1969   }
1970 
1971   TF_WhileParams params = {ninputs,    cond_graph,  cond_inputs,  cond_output,
1972                            body_graph, body_inputs, body_outputs, name};
1973 
1974   if (!status->status.ok()) {
1975     FreeWhileResources(&params);
1976     return EmptyWhileParams();
1977   }
1978   return params;
1979 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
1980 }
1981 
1982 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1983 namespace {
1984 
1985 // TODO(skyewm): make nodes in while loop unfetchable like in Python version
1986 void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
1987                           TF_Output* outputs) {
1988   if (!ValidateInputWhileParams(*params, status)) return;
1989 
1990   TF_Graph* parent = params->cond_graph->parent;
1991   TF_Output* parent_inputs = params->cond_graph->parent_inputs;
1992   int num_loop_vars = params->ninputs;
1993 
1994   mutex_lock l(parent->mu);
1995 
1996   // 'cond_fn' copies the cond graph into the parent graph.
1997   tensorflow::ops::CondGraphBuilderFn cond_fn =
1998       [params, parent](const tensorflow::Scope& scope,
1999                        const std::vector<tensorflow::Output>& inputs,
2000                        tensorflow::Output* output) {
2001         DCHECK_EQ(scope.graph(), &parent->graph);
2002         std::vector<tensorflow::Output> cond_output;
2003         TF_RETURN_IF_ERROR(CopyGraph(
2004             &params->cond_graph->graph, &parent->graph, &parent->refiner,
2005             params->cond_inputs, inputs, scope.impl()->name(),
2006             scope.impl()->control_deps(), &params->cond_output,
2007             /* nreturn_nodes */ 1, &cond_output));
2008         *output = cond_output[0];
2009         return Status::OK();
2010       };
2011 
2012   // 'body_fn' copies the body graph into the parent graph.
2013   tensorflow::ops::BodyGraphBuilderFn body_fn =
2014       [params, parent, num_loop_vars](
2015           const tensorflow::Scope& scope,
2016           const std::vector<tensorflow::Output>& inputs,
2017           std::vector<tensorflow::Output>* outputs) {
2018         DCHECK_EQ(scope.graph(), &parent->graph);
2019         TF_RETURN_IF_ERROR(
2020             CopyGraph(&params->body_graph->graph, &parent->graph,
2021                       &parent->refiner, params->body_inputs, inputs,
2022                       scope.impl()->name(), scope.impl()->control_deps(),
2023                       params->body_outputs, num_loop_vars, outputs));
2024         return Status::OK();
2025       };
2026 
2027   // Create the while loop using an internal scope.
2028   tensorflow::Scope scope =
2029       NewInternalScope(&parent->graph, &status->status, &parent->refiner)
2030           .NewSubScope(params->name);
2031 
2032   const int first_new_node_id = parent->graph.num_node_ids();
2033 
2034   tensorflow::OutputList loop_outputs;
2035   status->status = tensorflow::ops::BuildWhileLoop(
2036       scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
2037       body_fn, params->name, &loop_outputs);
2038 
2039   // Update name_map with newly-created ops.
2040   // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
2041   // a bad status. Once we fix this, we may want to return early instead of
2042   // executing the following code.
2043   for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
2044     Node* new_node = parent->graph.FindNodeId(i);
2045     if (new_node == nullptr) continue;
2046     parent->name_map[new_node->name()] = new_node;
2047   }
2048 
2049   // Populate 'outputs'.
2050   DCHECK_LE(loop_outputs.size(), num_loop_vars);
2051   for (int i = 0; i < loop_outputs.size(); ++i) {
2052     outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
2053   }
2054 }
2055 
2056 }  // namespace
2057 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2058 
2059 void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
2060                     TF_Output* outputs) {
2061 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2062   status->status = tensorflow::errors::Unimplemented(
2063       "Creating while loops is not supported on mobile. File a bug at "
2064       "https://github.com/tensorflow/tensorflow/issues if this feature is "
2065       "important to you");
2066 #else
2067   // If it appears the caller created or modified `params`, don't free resources
2068   if (!ValidateConstWhileParams(*params, status)) return;
2069   TF_FinishWhileHelper(params, status, outputs);
2070   FreeWhileResources(params);
2071 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2072 }
2073 
2074 void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
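
// Usage sketch of the while-loop API: build the condition in
// params.cond_graph and the body in params.body_graph (e.g. with
// TF_NewOperation()/TF_FinishOperation()), fill in cond_output, body_outputs
// and name, then call TF_FinishWhile(), or TF_AbortWhile() to discard. The
// values below are placeholders, not a complete program; `g`, `inputs`,
// `ninputs` and `s` are assumed valid and error handling is omitted.
//
//   TF_WhileParams params = TF_NewWhile(g, inputs, ninputs, s);
//   params.cond_output = /* a boolean TF_Output built in params.cond_graph */;
//   for (int i = 0; i < params.ninputs; ++i) {
//     params.body_outputs[i] = /* next value built in params.body_graph */;
//   }
//   params.name = "my_while";
//   std::vector<TF_Output> outputs(ninputs);
//   TF_FinishWhile(&params, s, outputs.data());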
2075 
2076 void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
2077                      TF_Output* dx, TF_Status* status, TF_Output* dy) {
2078   TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy);
2079 }
2080 
2081 void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
2082                                int ny, TF_Output* x, int nx, TF_Output* dx,
2083                                TF_Status* status, TF_Output* dy) {
2084 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2085   status->status = tensorflow::errors::Unimplemented(
2086       "Adding gradients is not supported on mobile. File a bug at "
2087       "https://github.com/tensorflow/tensorflow/issues if this feature is "
2088       "important to you");
2089 #else
2090   std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
2091   std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
2092   std::vector<tensorflow::Output> dy_arg;
2093 
2094   {
2095     // We need to hold on to the lock while we have a scope that uses TF_Graph.
2096     mutex_lock graph_lock(g->mu);
2097 
2098     const int first_new_node_id = g->graph.num_node_ids();
2099 
2100     string prefix_cmp;
2101     const char* child_scope_name;
2102     if (prefix == nullptr) {
2103       child_scope_name = "gradients";
2104     } else {
2105       prefix_cmp = string(prefix) + "/";
2106       // The operation should fail if the provided name prefix has already been
2107       // used in this graph
2108       for (const auto& pair : g->name_map) {
2109         const string& name = pair.first;
2110         if ((name == prefix) || absl::StartsWith(name, prefix_cmp)) {
2111           status->status = InvalidArgument(
2112               "prefix [", prefix,
2113               "] conflicts with existing node in the graph named [", name, "]");
2114           return;
2115         }
2116       }
2117       child_scope_name = prefix;
2118     }
2119     tensorflow::Scope scope =
2120         NewInternalScope(&g->graph, &status->status, &g->refiner)
2121             .NewSubScope(child_scope_name);
2122 
2123     if (dx != nullptr) {
2124       std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
2125       status->status =
2126           AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
2127     } else {
2128       status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
2129     }
2130 
2131     // Update g->name_map with the name_map from the scope, which will contain
2132     // the new gradient ops.
2133     for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
2134       Node* n = g->graph.FindNodeId(i);
2135       if (n == nullptr) continue;
2136 
2137       // Adding the gradients to the graph can alter the prefix to prevent
2138       // name collisions only if this prefix has not been provided explicitly
2139       // by the user. If it was provided, assert that it remained intact.
2140       if (prefix != nullptr && !absl::StartsWith(n->name(), prefix_cmp)) {
2141         status->status = tensorflow::errors::Internal(
2142             "BUG: The gradients prefix have been unexpectedly altered when "
2143             "adding the nodes to the graph. This is a bug. Please file an "
2144             "issue at https://github.com/tensorflow/tensorflow/issues.");
2145         return;
2146       }
2147       // We have a convoluted scheme here: Using the C++ graph construction API
2148       // to add potentially many nodes to the graph without running the checks
2149       // (such as uniqueness of the names of nodes) we run with other functions
2150       // that add a node to the graph (like TF_FinishOperation).
2151       if (!g->name_map.insert(std::make_pair(n->name(), n)).second) {
2152         status->status = tensorflow::errors::Internal(
2153             "BUG: The API allowed construction of a graph with duplicate node "
2154             "names (",
2155             n->name(),
2156             "). This is a bug. Please file an issue at "
2157             "https://github.com/tensorflow/tensorflow/issues.");
2158       }
2159     }
2160   }
2161 
2162   // Unpack the results from grad_outputs_arg.
2163   TFOutputsFromOutputs(dy_arg, dy);
2164 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2165 }
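
// Usage sketch: dy[i] receives the symbolic partial derivative of (the sum
// of) y with respect to x[i]; passing dx == nullptr seeds the gradient with
// ones. `g`, `s`, `y_out` and `x_out` (TF_Output values already in `g`) are
// assumptions for illustration.
//
//   TF_Output dy_out;
//   TF_AddGradientsWithPrefix(g, "grads", &y_out, 1, &x_out, 1,
//                             /*dx=*/nullptr, s, &dy_out);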
2166 
2167 // TF_Session functions ----------------------------------------------
2168 
2169 TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
2170     : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {}
2171 
2172 TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
2173                           TF_Status* status) {
2174   Session* session;
2175   status->status = NewSession(opt->options, &session);
2176   if (status->status.ok()) {
2177     TF_Session* new_session = new TF_Session(session, graph);
2178     if (graph != nullptr) {
2179       mutex_lock l(graph->mu);
2180       graph->sessions[new_session] = "";
2181     }
2182     return new_session;
2183   } else {
2184     LOG(ERROR) << status->status;
2185     DCHECK_EQ(nullptr, session);
2186     return nullptr;
2187   }
2188 }
2189 
2190 TF_Session* TF_LoadSessionFromSavedModel(
2191     const TF_SessionOptions* session_options, const TF_Buffer* run_options,
2192     const char* export_dir, const char* const* tags, int tags_len,
2193     TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
2194 // TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring
2195 // that the tensorflow/cc/saved_model:loader build target is mobile friendly.
2196 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2197   status->status = tensorflow::errors::Unimplemented(
2198       "Loading a SavedModel is not supported on mobile. File a bug at "
2199       "https://github.com/tensorflow/tensorflow/issues if this feature is "
2200       "important to you");
2201   return nullptr;
2202 #else
2203   mutex_lock l(graph->mu);
2204   if (!graph->name_map.empty()) {
2205     status->status = InvalidArgument("Graph is non-empty.");
2206     return nullptr;
2207   }
2208 
2209   RunOptions run_options_proto;
2210   if (run_options != nullptr && !run_options_proto.ParseFromArray(
2211                                     run_options->data, run_options->length)) {
2212     status->status = InvalidArgument("Unparseable RunOptions proto");
2213     return nullptr;
2214   }
2215 
2216   std::unordered_set<string> tag_set;
2217   for (int i = 0; i < tags_len; i++) {
2218     tag_set.insert(string(tags[i]));
2219   }
2220 
2221   tensorflow::SavedModelBundle bundle;
2222   status->status =
2223       tensorflow::LoadSavedModel(session_options->options, run_options_proto,
2224                                  export_dir, tag_set, &bundle);
2225   if (!status->status.ok()) return nullptr;
2226 
2227   // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
2228   // extends using GraphDefs. The Graph instance is different, but equivalent
2229   // to the one used to create the session.
2230   //
2231   // TODO(jhseu): When Session is modified to take Graphs instead of
2232   // GraphDefs, return the Graph generated in LoadSavedModel().
2233   TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
2234   TF_ImportGraphDefResults results;
2235   GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
2236                             import_opts, &results, status);
2237   TF_DeleteImportGraphDefOptions(import_opts);
2238   if (!status->status.ok()) return nullptr;
2239 
2240   if (meta_graph_def != nullptr) {
2241     status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
2242     if (!status->status.ok()) return nullptr;
2243   }
2244 
2245   TF_Session* session = new TF_Session(bundle.session.release(), graph);
2246 
2247   graph->sessions[session] = "";
2248   session->last_num_graph_nodes = graph->graph.num_node_ids();
2249   return session;
2250 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2251 }
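
// Usage sketch: loading a SavedModel exported with the conventional "serve"
// tag into a fresh graph. `export_dir` is assumed to point at a valid
// SavedModel directory; error handling is omitted.
//
//   TF_Status* s = TF_NewStatus();
//   TF_Graph* graph = TF_NewGraph();
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   const char* tags[] = {"serve"};
//   TF_Session* session = TF_LoadSessionFromSavedModel(
//       opts, /*run_options=*/nullptr, export_dir, tags, 1, graph,
//       /*meta_graph_def=*/nullptr, s);
//   TF_DeleteSessionOptions(opts);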
2252 
2253 void TF_CloseSession(TF_Session* s, TF_Status* status) {
2254   status->status = s->session->Close();
2255 }
2256 
2257 void TF_DeleteSession(TF_Session* s, TF_Status* status) {
2258   status->status = Status::OK();
2259   if (s == nullptr) return;
2260   TF_Graph* const graph = s->graph;
2261   if (graph != nullptr) {
2262     graph->mu.lock();
2263     graph->sessions.erase(s);
2264     const bool del = graph->delete_requested && graph->sessions.empty();
2265     graph->mu.unlock();
2266     if (del) delete graph;
2267   }
2268   delete s->session;
2269   delete s;
2270 }
2271 
2272 void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
2273                    const TF_Output* inputs, TF_Tensor* const* input_values,
2274                    int ninputs, const TF_Output* outputs,
2275                    TF_Tensor** output_values, int noutputs,
2276                    const TF_Operation* const* target_opers, int ntargets,
2277                    TF_Buffer* run_metadata, TF_Status* status) {
2278   // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2279   // directly, instead of requiring us to serialize to a GraphDef and
2280   // call Session::Extend().
2281   if (session->extend_before_run &&
2282       !ExtendSessionGraphHelper(session, status)) {
2283     return;
2284   }
2285 
2286   TF_Run_Setup(noutputs, output_values, status);
2287 
2288   // Convert from TF_Output and TF_Tensor to a string and Tensor.
2289   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2290   if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2291   for (int i = 0; i < ninputs; ++i) {
2292     input_pairs[i].first = OutputName(inputs[i]);
2293   }
2294 
2295   // Convert from TF_Output to string names.
2296   std::vector<string> output_names(noutputs);
2297   for (int i = 0; i < noutputs; ++i) {
2298     output_names[i] = OutputName(outputs[i]);
2299   }
2300 
2301   // Convert from TF_Operation* to string names.
2302   std::vector<string> target_names(ntargets);
2303   for (int i = 0; i < ntargets; ++i) {
2304     target_names[i] = target_opers[i]->node.name();
2305   }
2306 
2307   // Actually run.
2308   TF_Run_Helper(session->session, nullptr, run_options, input_pairs,
2309                 output_names, output_values, target_names, run_metadata,
2310                 status);
2311 }
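
// Usage sketch: running one step with a single feed and a single fetch. The
// operation names "input" and "output", plus `graph`, `session`, `s` and
// `input_tensor`, are assumptions for illustration; the caller owns the
// returned output tensor.
//
//   TF_Output feed = {TF_GraphOperationByName(graph, "input"), 0};
//   TF_Output fetch = {TF_GraphOperationByName(graph, "output"), 0};
//   TF_Tensor* out = nullptr;
//   TF_SessionRun(session, /*run_options=*/nullptr, &feed, &input_tensor, 1,
//                 &fetch, &out, 1, /*target_opers=*/nullptr, 0,
//                 /*run_metadata=*/nullptr, s);
//   if (TF_GetCode(s) == TF_OK) TF_DeleteTensor(out);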
2312 
2313 void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
2314                          int ninputs, const TF_Output* outputs, int noutputs,
2315                          const TF_Operation* const* target_opers, int ntargets,
2316                          const char** handle, TF_Status* status) {
2317   *handle = nullptr;
2318 
2319   if (session->extend_before_run &&
2320       !ExtendSessionGraphHelper(session, status)) {
2321     return;
2322   }
2323 
2324   std::vector<string> input_names(ninputs);
2325   for (int i = 0; i < ninputs; ++i) {
2326     input_names[i] = OutputName(inputs[i]);
2327   }
2328 
2329   std::vector<string> output_names(noutputs);
2330   for (int i = 0; i < noutputs; ++i) {
2331     output_names[i] = OutputName(outputs[i]);
2332   }
2333 
2334   std::vector<string> target_names(ntargets);
2335   for (int i = 0; i < ntargets; ++i) {
2336     target_names[i] = target_opers[i]->node.name();
2337   }
2338 
2339   string new_handle;
2340   status->status = session->session->PRunSetup(input_names, output_names,
2341                                                target_names, &new_handle);
2342   if (status->status.ok()) {
2343     char* buf = new char[new_handle.size() + 1];
2344     memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
2345     *handle = buf;
2346   }
2347 }
2348 
2349 void TF_DeletePRunHandle(const char* handle) {
2350   delete[] handle;
2351   // TODO(suharshs): Free up any resources held by the partial run state.
2352 }
2353 
2354 void TF_SessionPRun(TF_Session* session, const char* handle,
2355                     const TF_Output* inputs, TF_Tensor* const* input_values,
2356                     int ninputs, const TF_Output* outputs,
2357                     TF_Tensor** output_values, int noutputs,
2358                     const TF_Operation* const* target_opers, int ntargets,
2359                     TF_Status* status) {
2360   // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2361   // directly, instead of requiring us to serialize to a GraphDef and
2362   // call Session::Extend().
2363   if (session->extend_before_run &&
2364       !ExtendSessionGraphHelper(session, status)) {
2365     return;
2366   }
2367 
2368   TF_Run_Setup(noutputs, output_values, status);
2369 
2370   // Convert from TF_Output and TF_Tensor to a string and Tensor.
2371   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2372   if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2373   for (int i = 0; i < ninputs; ++i) {
2374     input_pairs[i].first = OutputName(inputs[i]);
2375   }
2376 
2377   // Convert from TF_Output to string names.
2378   std::vector<string> output_names(noutputs);
2379   for (int i = 0; i < noutputs; ++i) {
2380     output_names[i] = OutputName(outputs[i]);
2381   }
2382 
2383   // Convert from TF_Operation* to string names.
2384   std::vector<string> target_names(ntargets);
2385   for (int i = 0; i < ntargets; ++i) {
2386     target_names[i] = target_opers[i]->node.name();
2387   }
2388 
2389   TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
2390                 output_values, target_names, nullptr, status);
2391 }
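
// Illustrative sketch (not part of the implementation): a typical partial-run
// sequence pairs TF_SessionPRunSetup with one or more TF_SessionPRun calls and
// finishes with TF_DeletePRunHandle. The `session`, `feed`, `fetch`, and
// `feed_tensor` names below are assumed to have been created by the caller.
//
//   TF_Status* status = TF_NewStatus();
//   const char* handle = nullptr;
//   TF_SessionPRunSetup(session, &feed, 1, &fetch, 1,
//                       /*target_opers=*/nullptr, 0, &handle, status);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_Tensor* out = nullptr;
//     TF_SessionPRun(session, handle, &feed, &feed_tensor, 1, &fetch, &out, 1,
//                    /*target_opers=*/nullptr, 0, status);
//     if (TF_GetCode(status) == TF_OK) TF_DeleteTensor(out);
//     TF_DeletePRunHandle(handle);
//   }
//   TF_DeleteStatus(status);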

unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
                                     TF_Tensor** result, TF_Status* status) {
  *result = nullptr;
  mutex_lock l(graph->mu);
  OutputTensor tensor(&output.oper->node, output.index);
  bool evaluated;
  Tensor result_tensor;
  status->status = EvaluateConstantTensor(
      tensor, graph->refiner, *graph->graph.op_registry(),
      graph->graph.versions().producer(), &evaluated, &result_tensor);
  if (evaluated) {
    DCHECK(status->status.ok());
    *result = TF_TensorFromTensor(result_tensor, &status->status);
    if (!status->status.ok()) evaluated = false;
  }
  return evaluated;
}
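
// Illustrative sketch: TF_TryEvaluateConstant returns non-zero and fills
// *result only when the output can be folded to a constant; the caller owns
// the returned tensor. `graph` and `some_output` are assumed caller-provided.
//
//   TF_Tensor* value = nullptr;
//   if (TF_TryEvaluateConstant(graph, some_output, &value, status) &&
//       TF_GetCode(status) == TF_OK) {
//     // ... inspect `value` ...
//     TF_DeleteTensor(value);
//   }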

TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
  tensorflow::OpList op_list;
  if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
    status->status = InvalidArgument("Unparseable OpList");
    return nullptr;
  }
  status->status = Status::OK();
  return new TF_ApiDefMap(op_list);
}

void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }

void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
                     size_t text_len, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "ApiDefMap is not supported on mobile.");
#else
  mutex_lock l(api_def_map->lock);
  if (api_def_map->update_docs_called) {
    status->status = FailedPrecondition(
        "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
        "called.");
    return;
  }
  string api_def_text(text, text_len);
  status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
                           size_t name_len, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "ApiDefMap is not supported on mobile.");
  return nullptr;
#else
  mutex_lock l(api_def_map->lock);
  if (!api_def_map->update_docs_called) {
    api_def_map->api_def_map.UpdateDocs();
    api_def_map->update_docs_called = true;
  }
  string name_str(name, name_len);
  const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
  if (api_def == nullptr) {
    return nullptr;
  }

  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(*api_def, ret);
  if (!status->status.ok()) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
  return ret;
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
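
// Illustrative sketch: the ApiDefMap functions are used together — build the
// map from a serialized OpList, optionally overlay ApiDef text via
// TF_ApiDefMapPut, then query per-op ApiDefs. TF_GetAllOpList() is assumed
// here as the source of the OpList buffer.
//
//   TF_Status* status = TF_NewStatus();
//   TF_Buffer* op_list = TF_GetAllOpList();
//   TF_ApiDefMap* api_map = TF_NewApiDefMap(op_list, status);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_Buffer* api_def = TF_ApiDefMapGet(api_map, "MatMul", 6, status);
//     if (api_def != nullptr) TF_DeleteBuffer(api_def);
//   }
//   TF_DeleteApiDefMap(api_map);
//   TF_DeleteBuffer(op_list);
//   TF_DeleteStatus(status);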

TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
  tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(kernel_list, ret);
  if (!status->status.ok()) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
  return ret;
}

TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
  tensorflow::KernelList kernel_list =
      tensorflow::GetRegisteredKernelsForOp(name);
  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(kernel_list, ret);
  if (!status->status.ok()) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
  return ret;
}
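
// Illustrative sketch: both kernel-registry queries return a TF_Buffer holding
// a serialized KernelList proto that the caller deserializes and frees; the
// C++ proto class is used here only for illustration.
//
//   TF_Status* status = TF_NewStatus();
//   TF_Buffer* kernels = TF_GetRegisteredKernelsForOp("MatMul", status);
//   if (TF_GetCode(status) == TF_OK) {
//     tensorflow::KernelList kernel_list;
//     kernel_list.ParseFromArray(kernels->data, kernels->length);
//     TF_DeleteBuffer(kernels);
//   }
//   TF_DeleteStatus(status);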

void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
                   TF_Status* status) {
  using tensorflow::RecordMutation;
  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(&new_src.oper->node);

  if (ic->num_outputs() <= new_src.index) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Output index [", new_src.index,
        "] is greater than the number of total outputs [", ic->num_outputs(),
        "].");
    return;
  }
  tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);

  tensorflow::shape_inference::InferenceContext* ic_dst =
      graph->refiner.GetContext(&dst.oper->node);
  if (ic_dst->num_inputs() <= dst.index) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Input index [", dst.index,
        "] is greater than the number of total inputs [", ic_dst->num_inputs(),
        "].");
    return;
  }
  if (!ic_dst->MergeInput(dst.index, shape)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
        " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
    return;
  }
  status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
                                           &dst.oper->node, dst.index);

  if (TF_GetCode(status) == TF_OK) {
    // This modification only updates the destination node for
    // the purposes of running this graph in a session. Thus, we don't
    // record the source node as being modified.
    RecordMutation(graph, *dst.oper, "updating input tensor");
  }
}
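
// Illustrative sketch: TF_UpdateEdge rewires input `dst` of an existing node
// to a new producer output. The operations `new_producer` and `consumer` are
// assumed to already exist in `graph`.
//
//   TF_Output new_src = {new_producer, 0};
//   TF_Input dst = {consumer, 0};
//   TF_UpdateEdge(graph, new_src, dst, status);
//   if (TF_GetCode(status) != TF_OK) {
//     // Incompatible shapes or out-of-range indices are reported here.
//   }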

// TF_Server functions ----------------------------------------------

#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
TF_Server::TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)
    : target(server->target()), server(std::move(server)) {}
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

TF_Server* TF_NewServer(const void* proto, size_t proto_len,
                        TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
  return nullptr;
#else
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, static_cast<int>(proto_len))) {
    status->status = InvalidArgument(
        "Could not parse provided bytes into a ServerDef protocol buffer");
    return nullptr;
  }

  std::unique_ptr<tensorflow::ServerInterface> out_server;
  status->status = tensorflow::NewServer(server_def, &out_server);
  if (!status->status.ok()) return nullptr;

  return new TF_Server(std::move(out_server));
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

void TF_ServerStart(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  status->status = server->server->Start();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

void TF_ServerStop(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  status->status = server->server->Stop();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

void TF_ServerJoin(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  status->status = server->server->Join();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

const char* TF_ServerTarget(TF_Server* server) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  return nullptr;
#else
  return server->target.c_str();
#endif
}

void TF_DeleteServer(TF_Server* server) {
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
  delete server;
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
}
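
// Illustrative sketch: a typical in-process server lifecycle. The
// `server_def_bytes`/`server_def_len` pair is assumed to hold a serialized
// tensorflow.ServerDef produced by the caller.
//
//   TF_Status* status = TF_NewStatus();
//   TF_Server* server = TF_NewServer(server_def_bytes, server_def_len, status);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_ServerStart(server, status);
//     // TF_ServerTarget(server) yields the target string for sessions that
//     // should connect to this server.
//     TF_ServerJoin(server, status);  // Blocks until the server shuts down.
//   }
//   TF_DeleteServer(server);
//   TF_DeleteStatus(status);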

void TF_RegisterLogListener(void (*listener)(const char*)) {
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
  tensorflow::logging::RegisterListener(listener);
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
}

void TF_RegisterFilesystemPlugin(const char* plugin_filename,
                                 TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "FileSystem plugin functionality is not supported on mobile");
#else
  status->status = tensorflow::RegisterFilesystemPlugin(plugin_filename);
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
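
// Illustrative sketch: registering a log listener and a modular filesystem
// plugin. The plugin path is a hypothetical placeholder, and the listener
// assumes <stdio.h> is available.
//
//   static void LogToStderr(const char* msg) { fprintf(stderr, "%s\n", msg); }
//
//   TF_RegisterLogListener(LogToStderr);
//   TF_Status* status = TF_NewStatus();
//   TF_RegisterFilesystemPlugin("/path/to/libtf_filesystem_plugin.so", status);
//   TF_DeleteStatus(status);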

}  // end extern "C"