• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api.h"
17 
18 #include <algorithm>
19 #include <limits>
20 #include <memory>
21 #include <vector>
22 
23 #include "absl/strings/match.h"
24 // Required for IS_MOBILE_PLATFORM
25 #include "tensorflow/core/platform/platform.h"  // NOLINT
26 
27 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
28 #include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
29 #include "tensorflow/cc/framework/gradients.h"
30 #include "tensorflow/cc/framework/ops.h"
31 #include "tensorflow/cc/framework/scope_internal.h"
32 #include "tensorflow/cc/ops/while_loop.h"
33 #include "tensorflow/cc/saved_model/loader.h"
34 #include "tensorflow/core/distributed_runtime/server_lib.h"
35 #include "tensorflow/core/framework/logging.h"
36 #include "tensorflow/core/framework/op_gen_lib.h"
37 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
38 #include "tensorflow/c/c_api_internal.h"
39 #include "tensorflow/c/tf_buffer_internal.h"
40 #include "tensorflow/c/tf_status_internal.h"
41 #include "tensorflow/c/tf_tensor.h"
42 #include "tensorflow/core/common_runtime/device_mgr.h"
43 #include "tensorflow/core/common_runtime/eval_const_tensor.h"
44 #include "tensorflow/core/common_runtime/graph_constructor.h"
45 #include "tensorflow/core/common_runtime/shape_refiner.h"
46 #include "tensorflow/core/framework/allocation_description.pb.h"
47 #include "tensorflow/core/framework/kernel_def.pb.h"
48 #include "tensorflow/core/framework/log_memory.h"
49 #include "tensorflow/core/framework/node_def_util.h"
50 #include "tensorflow/core/framework/op_kernel.h"
51 #include "tensorflow/core/framework/partial_tensor_shape.h"
52 #include "tensorflow/core/framework/tensor.h"
53 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
54 #include "tensorflow/core/framework/tensor_shape.h"
55 #include "tensorflow/core/framework/tensor_shape.pb.h"
56 #include "tensorflow/core/framework/types.h"
57 #include "tensorflow/core/framework/versions.pb.h"
58 #include "tensorflow/core/graph/graph.h"
59 #include "tensorflow/core/graph/node_builder.h"
60 #include "tensorflow/core/graph/validate.h"
61 #include "tensorflow/core/lib/gtl/array_slice.h"
62 #include "tensorflow/core/platform/coding.h"
63 #include "tensorflow/core/platform/errors.h"
64 #include "tensorflow/core/platform/mem.h"
65 #include "tensorflow/core/platform/mutex.h"
66 #include "tensorflow/core/platform/protobuf.h"
67 #include "tensorflow/core/platform/status.h"
68 #include "tensorflow/core/platform/str_util.h"
69 #include "tensorflow/core/platform/strcat.h"
70 #include "tensorflow/core/platform/stringpiece.h"
71 #include "tensorflow/core/platform/thread_annotations.h"
72 #include "tensorflow/core/platform/types.h"
73 #include "tensorflow/core/public/session.h"
74 #include "tensorflow/core/public/version.h"
75 
76 // The implementation below is at the top level instead of the
77 // brain namespace because we are defining 'extern "C"' functions.
78 using tensorflow::AllocationDescription;
79 using tensorflow::AttrValueMap;
80 using tensorflow::DataType;
81 using tensorflow::ExtendSessionGraphHelper;
82 using tensorflow::Graph;
83 using tensorflow::GraphDef;
84 using tensorflow::mutex_lock;
85 using tensorflow::NameRangeMap;
86 using tensorflow::NameRangesForNode;
87 using tensorflow::NewSession;
88 using tensorflow::Node;
89 using tensorflow::NodeBuilder;
90 using tensorflow::NodeDef;
91 using tensorflow::OpDef;
92 using tensorflow::OpRegistry;
93 using tensorflow::OutputTensor;
94 using tensorflow::PartialTensorShape;
95 using tensorflow::RunMetadata;
96 using tensorflow::RunOptions;
97 using tensorflow::Session;
98 using tensorflow::Status;
99 using tensorflow::string;
100 using tensorflow::Tensor;
101 using tensorflow::TensorBuffer;
102 using tensorflow::TensorId;
103 using tensorflow::TensorShape;
104 using tensorflow::TensorShapeProto;
105 using tensorflow::VersionDef;
106 using tensorflow::errors::FailedPrecondition;
107 using tensorflow::errors::InvalidArgument;
108 using tensorflow::errors::OutOfRange;
109 using tensorflow::gtl::ArraySlice;
110 using tensorflow::strings::StrCat;
111 
112 extern "C" {
113 
114 // --------------------------------------------------------------------------
// Returns the compile-time TensorFlow version string (TF_VERSION_STRING).
const char* TF_Version() { return TF_VERSION_STRING; }
116 
117 // --------------------------------------------------------------------------
118 
119 // --------------------------------------------------------------------------
// Allocates a default-initialized TF_SessionOptions; the caller releases it
// with TF_DeleteSessionOptions.
TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
// Frees a TF_SessionOptions created by TF_NewSessionOptions.
void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
122 
// Sets the session target (e.g. a remote server address). The C string is
// copied into the options, so the caller keeps ownership of `target`.
void TF_SetTarget(TF_SessionOptions* options, const char* target) {
  options->options.target = target;
}
126 
// Replaces the session configuration with the serialized ConfigProto held in
// `proto` (`proto_len` bytes). Sets an InvalidArgument status if the bytes
// do not parse; `status` is left untouched on success.
void TF_SetConfig(TF_SessionOptions* options, const void* proto,
                  size_t proto_len, TF_Status* status) {
  if (!options->options.config.ParseFromArray(proto, proto_len)) {
    status->status = InvalidArgument("Unparseable ConfigProto");
  }
}
133 
// Deserializes the TensorProto held in buffer `from` into the existing
// tensor `to`. Failures (unparseable buffer or an incompatible proto) are
// reported through `status`.
void TF_TensorFromProto(const TF_Buffer* from, TF_Tensor* to,
                        TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  tensorflow::TensorProto from_tensor_proto;
  status->status = BufferToMessage(from, &from_tensor_proto);
  if (!status->status.ok()) {
    return;
  }
  // `to->tensor` is assumed to be a TensorInterface; down_cast skips the
  // dynamic_cast cost in optimized builds.
  status->status =
      tensorflow::down_cast<tensorflow::TensorInterface*>(to->tensor)
          ->FromProto(from_tensor_proto);
}
146 // --------------------------------------------------------------------------
147 
TF_NewDeprecatedSession(const TF_SessionOptions * opt,TF_Status * status)148 TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
149                                               TF_Status* status) {
150   Session* session;
151   status->status = NewSession(opt->options, &session);
152   if (status->status.ok()) {
153     return new TF_DeprecatedSession({session});
154   } else {
155     DCHECK_EQ(nullptr, session);
156     return nullptr;
157   }
158 }
159 
// Closes the underlying Session, surfacing the result in `status`. The
// TF_DeprecatedSession wrapper itself is not freed here.
void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
  status->status = s->session->Close();
}
163 
// Destroys a legacy session and its wrapper. Always reports OK; a null `s`
// is tolerated as a no-op.
void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
  status->status = ::tensorflow::OkStatus();
  if (s == nullptr) return;
  delete s->session;
  delete s;
}
170 
// Adds the nodes from the serialized GraphDef (`proto`, `proto_len` bytes)
// to the session's graph. A parse failure yields InvalidArgument; otherwise
// the result of Session::Extend is surfaced through `status`.
void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
                    size_t proto_len, TF_Status* status) {
  GraphDef g;
  // ParseProtoUnlimited avoids protobuf's default message size limit, which
  // large graphs can legitimately exceed.
  if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
    status->status = InvalidArgument("Invalid GraphDef");
    return;
  }
  status->status = s->session->Extend(g);
}
180 
181 }  // end extern "C"
182 
183 // Reset helper for converting character arrays to string vectors.
TF_Reset_Helper(const TF_SessionOptions * opt,const char ** containers,int ncontainers,TF_Status * status)184 static void TF_Reset_Helper(const TF_SessionOptions* opt,
185                             const char** containers, int ncontainers,
186                             TF_Status* status) {
187   std::vector<string> container_names(ncontainers);
188   for (int i = 0; i < ncontainers; ++i) {
189     container_names[i] = containers[i];
190   }
191 
192   status->status = Reset(opt->options, container_names);
193 }
194 
195 extern "C" {
196 
// Public C entry point for resetting resource containers; see
// TF_Reset_Helper for the conversion to C++ types.
void TF_Reset(const TF_SessionOptions* opt, const char** containers,
              int ncontainers, TF_Status* status) {
  TF_Reset_Helper(opt, containers, ncontainers, status);
}
201 
202 }  // end extern "C"
203 
204 namespace tensorflow {
205 
RecordMutation(TF_Graph * graph,const TF_Operation & op,const char * mutation_type)206 void RecordMutation(TF_Graph* graph, const TF_Operation& op,
207                     const char* mutation_type) {
208   // If any session has already run this node_id, mark this session as
209   // unrunnable.
210   for (auto it : graph->sessions) {
211     mutex_lock session_lock(it.first->mu);
212     if (it.first->last_num_graph_nodes > op.node.id()) {
213       it.second = strings::StrCat(
214           "Operation '", op.node.DebugString(), "' was changed by ",
215           mutation_type,
216           " after it was run by a session. This mutation will have no effect, "
217           "and will trigger an error in the future. Either don't modify "
218           "nodes after running them or create a new session.");
219     }
220   }
221 }
222 
223 namespace {
224 
225 // Helper method that creates a shape handle for a shape described by dims.
ShapeHandleFromDims(tensorflow::shape_inference::InferenceContext * ic,int num_dims,const int64_t * dims)226 tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
227     tensorflow::shape_inference::InferenceContext* ic, int num_dims,
228     const int64_t* dims) {
229   if (num_dims != -1) {
230     std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
231     dim_vec.reserve(num_dims);
232     for (int i = 0; i < num_dims; ++i) {
233       dim_vec.push_back(ic->MakeDim(dims[i]));
234     }
235     return ic->MakeShape(dim_vec);
236   } else {
237     return ic->UnknownShape();
238   }
239 }
240 
241 }  // namespace
242 
// Records, in the graph's shape refiner, the shapes and dtypes associated
// with the handle produced by `output`. Entry i is described by shapes[i]
// with rank ranks[i] (-1 meaning unknown rank, in which case shapes[i] is
// ignored) and dtype types[i].
void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
                                           int num_shapes_and_types,
                                           const int64_t** shapes,
                                           const int* ranks,
                                           const TF_DataType* types,
                                           TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return;
  }

  auto shape_and_type_vec =
      std::vector<tensorflow::shape_inference::ShapeAndType>(
          num_shapes_and_types);
  for (int i = 0; i < num_shapes_and_types; ++i) {
    tensorflow::shape_inference::ShapeHandle shape_handle =
        ShapeHandleFromDims(ic, ranks[i], shapes[i]);
    shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
        shape_handle, static_cast<DataType>(types[i]));
  }

  ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
}
272 
273 // Helpers for loading a TensorFlow plugin (a .so file).
274 Status LoadDynamicLibrary(const char* library_filename, void** result,
275                           const void** buf, size_t* len);
276 
277 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
278 // directly, instead of requiring us to serialize to a GraphDef and
279 // call Session::Extend().
// Serializes any nodes added to the graph since the last call and feeds them
// to Session::Extend so the session can run them. Returns true on success;
// on failure, sets `status` and returns false.
//
// Locking: graph->mu is taken before session->mu (see comment below) and is
// released MANUALLY on every exit path — take care when editing.
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
  if (session->graph != nullptr) {
    // Take the graph lock before the session lock to avoid deadlock. This is
    // safe since session->graph does not change.
    session->graph->mu.lock();
    mutex_lock session_lock(session->mu);
    const Graph& graph = session->graph->graph;

    // Surface (once) any warning recorded for this session by RecordMutation.
    const string& mutation_warning = session->graph->sessions[session];
    if (!mutation_warning.empty()) {
      // TODO(b/74949947): turn this back into an error status
      LOG(WARNING) << mutation_warning;
      session->graph->sessions[session].clear();
    }

    const auto num_nodes = graph.num_node_ids();
    if (session->last_num_graph_nodes < num_nodes) {
      // TODO(nolivia): check this on a subset of the graph instead of all of
      // it.
      status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
      if (!status->status.ok()) {
        session->graph->mu.unlock();
        return false;
      }

      GraphDef graph_def;
      *graph_def.mutable_versions() = graph.versions();
      // Fill graph_def with nodes with ids in the range
      // [session->last_num_graph_nodes, num_nodes), that is the nodes
      // added since the last TF_SessionRun() call.
      for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
        Node* const node = graph.FindNodeId(id);
        if (node != nullptr && node->IsOp()) {
          NodeDef* const node_def = graph_def.add_node();
          *node_def = node->def();
        }
      }
      *graph_def.mutable_library() = graph.flib_def().ToProto();
      // Drop the graph lock before the potentially slow Extend call.
      session->graph->mu.unlock();
      status->status = session->session->Extend(std::move(graph_def));
      if (!status->status.ok()) {
        // Contract is we always delete input_values[i].
        return false;
      }
      // Note: session->session is not modified if Extend() fails, so
      // we only set last_num_graph_nodes if it succeeds.
      session->last_num_graph_nodes = num_nodes;
    } else {
      session->graph->mu.unlock();
    }
  }
  return true;
}
333 
334 }  // namespace tensorflow
335 
TF_Run_Setup(int noutputs,TF_Tensor ** c_outputs,TF_Status * status)336 static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
337                          TF_Status* status) {
338   status->status = ::tensorflow::OkStatus();
339   for (int i = 0; i < noutputs; ++i) {
340     c_outputs[i] = nullptr;
341   }
342 }
343 
// TF_TensorToTensorV1 decodes a string serialization to DT_RESOURCE.
// In the TFv1 convention, TF_Tensor can hold a string serialization of
// DT_RESOURCE. The string serialization is converted back to a
// ResourceHandle during Session run where the TF_Tensor is converted to a
// Tensor.
// TFv2 does not depend on this conversion. There is no matching
// TF_TensorFromTensorV1 because the conversion to string is performed by the
// python side of Session.
static Status TF_TensorToTensorV1(const TF_Tensor* src, Tensor* dst) {
  Status status = TF_TensorToTensor(src, dst);
  if (!status.ok()) {
    return status;
  }
  if (dst->dtype() == tensorflow::DT_RESOURCE) {
    const auto tensor_interface =
        tensorflow::down_cast<const tensorflow::TensorInterface*>(src->tensor);

    // The v1 encoding only supports scalar resource tensors.
    if (dst->dims() != 0) {
      return InvalidArgument(
          "Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
          "shape ",
          dst->shape().DebugString());
    }
    // Replace the converted tensor with a DT_RESOURCE scalar, then parse the
    // ResourceHandle out of the original serialized bytes.
    *dst = tensorflow::Tensor(tensorflow::DT_RESOURCE, dst->shape());
    if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
            string(static_cast<const char*>(tensor_interface->Data()),
                   tensor_interface->ByteSize()))) {
      return InvalidArgument(
          "Malformed TF_RESOURCE tensor: unable to parse resource handle");
    }
    return ::tensorflow::OkStatus();
  }
  return ::tensorflow::OkStatus();
}
378 
TF_Run_Inputs(TF_Tensor * const * c_inputs,std::vector<std::pair<string,Tensor>> * input_pairs,TF_Status * status)379 static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
380                           std::vector<std::pair<string, Tensor>>* input_pairs,
381                           TF_Status* status) {
382   const int ninputs = input_pairs->size();
383   for (int i = 0; i < ninputs; ++i) {
384     status->status =
385         TF_TensorToTensorV1(c_inputs[i], &(*input_pairs)[i].second);
386     if (!status->status.ok()) return false;
387   }
388   return true;
389 }
390 
// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
// result in a zero-sized tensor.
static TF_Tensor* EmptyTensor(TF_DataType dtype,
                              const tensorflow::TensorShape& shape) {
  // Stable non-null base pointer for the zero-byte buffer; the no-op
  // deallocator below matches (nothing to free).
  static char empty;
  int64_t nelems = 1;
  std::vector<int64_t> dims;
  dims.reserve(shape.dims());
  for (int i = 0; i < shape.dims(); ++i) {
    dims.push_back(shape.dim_size(i));
    nelems *= shape.dim_size(i);
  }
  // Enforce the zero-sized contract stated above.
  CHECK_EQ(nelems, 0);
  return TF_NewTensor(
      dtype, reinterpret_cast<const int64_t*>(dims.data()), shape.dims(),
      reinterpret_cast<void*>(&empty), 0, [](void*, size_t, void*) {}, nullptr);
}
408 
// Shared implementation of TF_Run (handle == nullptr) and TF_PRun
// (handle != nullptr): runs the session, then converts each output Tensor
// into a TF_Tensor stored in c_outputs. On any failure, `status` is set and
// remaining outputs are left as initialized by TF_Run_Setup.
static void TF_Run_Helper(
    Session* session, const char* handle, const TF_Buffer* run_options,
    // Input tensors
    const std::vector<std::pair<string, Tensor>>& input_pairs,
    // Output tensors
    const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
    // Target nodes
    const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
    TF_Status* status) {
  const int noutputs = output_tensor_names.size();
  std::vector<Tensor> outputs(noutputs);
  Status result;

  if (handle == nullptr) {
    RunOptions run_options_proto;
    if (run_options != nullptr && !run_options_proto.ParseFromArray(
                                      run_options->data, run_options->length)) {
      status->status = InvalidArgument("Unparseable RunOptions proto");
      return;
    }
    // run_metadata is an out-parameter; a pre-filled buffer is rejected.
    if (run_metadata != nullptr && run_metadata->data != nullptr) {
      status->status =
          InvalidArgument("Passing non-empty run_metadata is invalid.");
      return;
    }

    RunMetadata run_metadata_proto;
    result = session->Run(run_options_proto, input_pairs, output_tensor_names,
                          target_oper_names, &outputs, &run_metadata_proto);

    // Serialize back to upstream client, who now owns the new buffer
    if (run_metadata != nullptr) {
      status->status = MessageToBuffer(run_metadata_proto, run_metadata);
      if (!status->status.ok()) return;
    }
  } else {
    // NOTE(zongheng): PRun does not support RunOptions yet.
    result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
  }
  if (!result.ok()) {
    status->status = result;
    return;
  }

  // Store results in c_outputs[]
  for (int i = 0; i < noutputs; ++i) {
    const Tensor& src = outputs[i];
    // Hand back a zero-sized placeholder for uninitialized/empty outputs
    // instead of converting them.
    if (!src.IsInitialized() || src.NumElements() == 0) {
      c_outputs[i] =
          EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
      continue;
    }
    c_outputs[i] = TF_TensorFromTensor(src, &status->status);
    if (!status->status.ok()) return;
  }
}
465 
466 extern "C" {
467 
TF_Run(TF_DeprecatedSession * s,const TF_Buffer * run_options,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Buffer * run_metadata,TF_Status * status)468 void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
469             // Input tensors
470             const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
471             // Output tensors
472             const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
473             // Target nodes
474             const char** c_target_oper_names, int ntargets,
475             TF_Buffer* run_metadata, TF_Status* status) {
476   TF_Run_Setup(noutputs, c_outputs, status);
477   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
478   if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
479   for (int i = 0; i < ninputs; ++i) {
480     input_pairs[i].first = c_input_names[i];
481   }
482   std::vector<string> output_names(noutputs);
483   for (int i = 0; i < noutputs; ++i) {
484     output_names[i] = c_output_names[i];
485   }
486   std::vector<string> target_oper_names(ntargets);
487   for (int i = 0; i < ntargets; ++i) {
488     target_oper_names[i] = c_target_oper_names[i];
489   }
490   TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
491                 c_outputs, target_oper_names, run_metadata, status);
492 }
493 
TF_PRunSetup(TF_DeprecatedSession * s,const char ** c_input_names,int ninputs,const char ** c_output_names,int noutputs,const char ** c_target_oper_names,int ntargets,const char ** handle,TF_Status * status)494 void TF_PRunSetup(TF_DeprecatedSession* s,
495                   // Input names
496                   const char** c_input_names, int ninputs,
497                   // Output names
498                   const char** c_output_names, int noutputs,
499                   // Target nodes
500                   const char** c_target_oper_names, int ntargets,
501                   const char** handle, TF_Status* status) {
502   *handle = nullptr;
503 
504   std::vector<string> input_names(ninputs);
505   std::vector<string> output_names(noutputs);
506   std::vector<string> target_oper_names(ntargets);
507   for (int i = 0; i < ninputs; ++i) {
508     input_names[i] = c_input_names[i];
509   }
510   for (int i = 0; i < noutputs; ++i) {
511     output_names[i] = c_output_names[i];
512   }
513   for (int i = 0; i < ntargets; ++i) {
514     target_oper_names[i] = c_target_oper_names[i];
515   }
516   string new_handle;
517   status->status = s->session->PRunSetup(input_names, output_names,
518                                          target_oper_names, &new_handle);
519   if (status->status.ok()) {
520     char* buf = new char[new_handle.size() + 1];
521     memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
522     *handle = buf;
523   }
524 }
525 
TF_PRun(TF_DeprecatedSession * s,const char * handle,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Status * status)526 void TF_PRun(TF_DeprecatedSession* s, const char* handle,
527              // Input tensors
528              const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
529              // Output tensors
530              const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
531              // Target nodes
532              const char** c_target_oper_names, int ntargets,
533              TF_Status* status) {
534   TF_Run_Setup(noutputs, c_outputs, status);
535   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
536   if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
537   for (int i = 0; i < ninputs; ++i) {
538     input_pairs[i].first = c_input_names[i];
539   }
540 
541   std::vector<string> output_names(noutputs);
542   for (int i = 0; i < noutputs; ++i) {
543     output_names[i] = c_output_names[i];
544   }
545   std::vector<string> target_oper_names(ntargets);
546   for (int i = 0; i < ntargets; ++i) {
547     target_oper_names[i] = c_target_oper_names[i];
548   }
549   TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
550                 c_outputs, target_oper_names, nullptr, status);
551 }
552 
// Loads the dynamic library at `library_filename` and returns a handle that
// also carries the serialized OpList of ops the library registers. Returns
// nullptr, with `status` set, on failure.
TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
  TF_Library* lib_handle = new TF_Library;
  status->status = tensorflow::LoadDynamicLibrary(
      library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
      &lib_handle->op_list.length);
  if (!status->status.ok()) {
    delete lib_handle;
    return nullptr;
  }
  return lib_handle;
}
564 
// Returns the serialized OpList captured when the library was loaded. The
// buffer's memory remains owned by `lib_handle`.
TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
566 
// Frees a TF_Library handle and its op-list buffer; null is a no-op. Note
// that nothing here unloads the underlying shared library.
void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
  if (lib_handle == nullptr) return;
  tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
  delete lib_handle;
}
572 
// Returns a newly allocated TF_Buffer holding a serialized OpList of every
// op in the global registry. The caller owns the returned buffer.
TF_Buffer* TF_GetAllOpList() {
  std::vector<tensorflow::OpDef> op_defs;
  tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
  tensorflow::OpList op_list;
  for (const auto& op : op_defs) {
    *(op_list.add_op()) = op;
  }
  TF_Buffer* ret = TF_NewBuffer();
  // Serialization of an in-memory proto is not expected to fail.
  TF_CHECK_OK(MessageToBuffer(op_list, ret));
  return ret;
}
584 
585 // --------------------------------------------------------------------------
586 // ListDevices & SessionListDevices API
587 
// Frees a device list returned by the ListDevices APIs.
void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; }
589 
// Lists the devices visible to `session`. If `session` or its underlying
// Session is null, the returned list is empty.
// NOTE(review): `status` is not written on the null-session path — callers
// should not assume it is set to OK here; confirm this is intended.
TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
  TF_DeviceList* response = new TF_DeviceList;
  if (session && session->session)
    status->status = session->session->ListDevices(&response->response);
  return response;
}
596 
// Legacy-session variant of TF_SessionListDevices; same contract, including
// the unwritten `status` when `session` or its Session is null.
TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
                                               TF_Status* status) {
  TF_DeviceList* response = new TF_DeviceList;
  if (session && session->session)
    status->status = session->session->ListDevices(&response->response);
  return response;
}
604 
// Number of devices in the list.
int TF_DeviceListCount(const TF_DeviceList* list) {
  return list->response.size();
}
608 
// Generates an accessor `method_name` over TF_DeviceList entries: validates
// the list pointer and index (setting InvalidArgument and returning
// `err_val` on bad input), otherwise sets OK and returns
// list->response[index].accessor.
#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
  return_type method_name(const TF_DeviceList* list, const int index,     \
                          TF_Status* status) {                            \
    if (list == nullptr) {                                                \
      status->status = InvalidArgument("list is null!");                  \
      return err_val;                                                     \
    }                                                                     \
    if (index < 0 || index >= list->response.size()) {                    \
      status->status = InvalidArgument("index out of bounds");            \
      return err_val;                                                     \
    }                                                                     \
    status->status = ::tensorflow::OkStatus();                            \
    return list->response[index].accessor;                                \
  }

TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
                     nullptr);
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);

#undef TF_DEVICELIST_METHOD
631 
632 }  // end extern "C"
633 
634 // --------------------------------------------------------------------------
635 // New Graph and Session API
636 
637 // Helper functions -----------------------------------------------------------
638 
639 namespace {
640 
// Reinterprets a Node* as the opaque TF_Operation* exposed by the C API.
// NOTE(review): assumes TF_Operation is layout-compatible with Node (its
// wrapper in c_api_internal.h) — the cast via void* relies on that.
TF_Operation* ToOperation(Node* node) {
  return static_cast<TF_Operation*>(static_cast<void*>(node));
}
644 
// Renders a TF_Output as the canonical "node_name:index" string.
string OutputName(const TF_Output& output) {
  return StrCat(output.oper->node.name(), ":", output.index);
}
648 
// Looks up attr `attr_name` on the operation. Returns nullptr, with an
// InvalidArgument status, if the attr does not exist.
const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
                                          const char* attr_name,
                                          TF_Status* status) {
  const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
  if (attr == nullptr) {
    status->status = InvalidArgument("Operation '", oper->node.name(),
                                     "' has no attr named '", attr_name, "'.");
  }
  return attr;
}
659 
// Converts a TF_Output into the (node name, output index) TensorId form.
TensorId ToTensorId(const TF_Output& output) {
  return TensorId(output.oper->node.name(), output.index);
}
663 
664 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
OutputsFromTFOutputs(TF_Output * tf_outputs,int n)665 std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
666                                                      int n) {
667   std::vector<tensorflow::Output> outputs(n);
668   for (int i = 0; i < n; ++i) {
669     outputs[i] =
670         tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
671   }
672   return outputs;
673 }
674 
TFOutputsFromOutputs(const std::vector<tensorflow::Output> & outputs,TF_Output * tf_outputs)675 void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
676                           TF_Output* tf_outputs) {
677   for (int i = 0; i < outputs.size(); i++) {
678     tf_outputs[i].oper = ToOperation(outputs[i].node());
679     tf_outputs[i].index = outputs[i].index();
680   }
681 }
682 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
683 
684 }  // namespace
685 
686 // Shape functions -----------------------------------------------------------
687 
// Sets the refiner's shape for `output` to the shape described by
// dims/num_dims (num_dims == -1 meaning unknown rank). Reports
// InvalidArgument if the node has no inference context; otherwise surfaces
// the result of ShapeRefiner::SetShape in `status`.
void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
                            const int64_t* dims, const int num_dims,
                            TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return;
  }
  tensorflow::shape_inference::ShapeHandle new_shape =
      tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
  status->status = graph->refiner.SetShape(node, output.index, new_shape);
}
705 
TF_GraphGetTensorNumDims(TF_Graph * graph,TF_Output output,TF_Status * status)706 int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
707                              TF_Status* status) {
708   Node* node = &output.oper->node;
709 
710   mutex_lock l(graph->mu);
711   tensorflow::shape_inference::InferenceContext* ic =
712       graph->refiner.GetContext(node);
713   if (ic == nullptr) {
714     status->status =
715         InvalidArgument("Node ", node->name(), " was not found in the graph");
716     return -1;
717   }
718 
719   tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
720 
721   // Unknown rank means the number of dimensions is -1.
722   if (!ic->RankKnown(shape)) {
723     return -1;
724   }
725 
726   return ic->Rank(shape);
727 }
728 
TF_GraphGetTensorShape(TF_Graph * graph,TF_Output output,int64_t * dims,int num_dims,TF_Status * status)729 void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
730                             int num_dims, TF_Status* status) {
731   Node* node = &output.oper->node;
732 
733   mutex_lock l(graph->mu);
734   tensorflow::shape_inference::InferenceContext* ic =
735       graph->refiner.GetContext(node);
736   if (ic == nullptr) {
737     status->status =
738         InvalidArgument("Node ", node->name(), " was not found in the graph");
739     return;
740   }
741 
742   tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
743 
744   int rank = -1;
745   if (ic->RankKnown(shape)) {
746     rank = ic->Rank(shape);
747   }
748 
749   if (num_dims != rank) {
750     status->status = InvalidArgument("Expected rank is ", num_dims,
751                                      " but actual rank is ", rank);
752     return;
753   }
754 
755   if (num_dims == 0) {
756     // Output shape is a scalar.
757     return;
758   }
759 
760   // Rank is greater than 0, so fill in the values, if known, and
761   // -1 for unknown values.
762   for (int i = 0; i < num_dims; ++i) {
763     auto dim = ic->Dim(shape, i);
764     int64_t value = -1;
765     if (ic->ValueKnown(dim)) {
766       value = ic->Value(dim);
767     }
768     dims[i] = value;
769   }
770 }
771 
772 // TF_OperationDescription functions ------------------------------------------
773 
774 extern "C" {
775 
TF_NewOperationLocked(TF_Graph * graph,const char * op_type,const char * oper_name)776 TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
777                                                const char* op_type,
778                                                const char* oper_name)
779     TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
780   return new TF_OperationDescription(graph, op_type, oper_name);
781 }
782 
TF_NewOperation(TF_Graph * graph,const char * op_type,const char * oper_name)783 TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
784                                          const char* oper_name) {
785   mutex_lock l(graph->mu);
786   return TF_NewOperationLocked(graph, op_type, oper_name);
787 }
788 
TF_SetDevice(TF_OperationDescription * desc,const char * device)789 void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
790   desc->node_builder.Device(device);
791 }
792 
TF_AddInput(TF_OperationDescription * desc,TF_Output input)793 void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
794   desc->node_builder.Input(&input.oper->node, input.index);
795 }
796 
TF_AddInputList(TF_OperationDescription * desc,const TF_Output * inputs,int num_inputs)797 void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
798                      int num_inputs) {
799   std::vector<NodeBuilder::NodeOut> input_list;
800   input_list.reserve(num_inputs);
801   for (int i = 0; i < num_inputs; ++i) {
802     input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
803   }
804   desc->node_builder.Input(input_list);
805 }
806 
TF_AddControlInput(TF_OperationDescription * desc,TF_Operation * input)807 void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
808   desc->node_builder.ControlInput(&input->node);
809 }
810 
TF_ColocateWith(TF_OperationDescription * desc,TF_Operation * op)811 void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
812   desc->colocation_constraints.emplace(
813       StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
814 }
815 
TF_SetAttrString(TF_OperationDescription * desc,const char * attr_name,const void * value,size_t length)816 void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
817                       const void* value, size_t length) {
818   tensorflow::StringPiece s(static_cast<const char*>(value), length);
819   desc->node_builder.Attr(attr_name, s);
820 }
821 
TF_SetAttrStringList(TF_OperationDescription * desc,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)822 void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
823                           const void* const* values, const size_t* lengths,
824                           int num_values) {
825   if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
826     desc->colocation_constraints.clear();
827     for (int i = 0; i < num_values; ++i) {
828       desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
829                                            lengths[i]);
830     }
831   } else {
832     std::vector<tensorflow::StringPiece> v;
833     v.reserve(num_values);
834     for (int i = 0; i < num_values; ++i) {
835       v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
836     }
837     desc->node_builder.Attr(attr_name, v);
838   }
839 }
840 
TF_SetAttrInt(TF_OperationDescription * desc,const char * attr_name,int64_t value)841 void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
842                    int64_t value) {
843   desc->node_builder.Attr(attr_name, static_cast<int64_t>(value));
844 }
845 
TF_SetAttrIntList(TF_OperationDescription * desc,const char * attr_name,const int64_t * values,int num_values)846 void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
847                        const int64_t* values, int num_values) {
848   desc->node_builder.Attr(
849       attr_name, ArraySlice<const int64_t>(
850                      reinterpret_cast<const int64_t*>(values), num_values));
851 }
852 
TF_SetAttrFloat(TF_OperationDescription * desc,const char * attr_name,float value)853 void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
854                      float value) {
855   desc->node_builder.Attr(attr_name, value);
856 }
857 
TF_SetAttrFloatList(TF_OperationDescription * desc,const char * attr_name,const float * values,int num_values)858 void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
859                          const float* values, int num_values) {
860   desc->node_builder.Attr(attr_name,
861                           ArraySlice<const float>(values, num_values));
862 }
863 
TF_SetAttrBool(TF_OperationDescription * desc,const char * attr_name,unsigned char value)864 void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
865                     unsigned char value) {
866   desc->node_builder.Attr(attr_name, static_cast<bool>(value));
867 }
868 
TF_SetAttrBoolList(TF_OperationDescription * desc,const char * attr_name,const unsigned char * values,int num_values)869 void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
870                         const unsigned char* values, int num_values) {
871   std::unique_ptr<bool[]> b(new bool[num_values]);
872   for (int i = 0; i < num_values; ++i) {
873     b[i] = values[i];
874   }
875   desc->node_builder.Attr(attr_name,
876                           ArraySlice<const bool>(b.get(), num_values));
877 }
878 
TF_SetAttrType(TF_OperationDescription * desc,const char * attr_name,TF_DataType value)879 void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
880                     TF_DataType value) {
881   desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
882 }
883 
TF_SetAttrTypeList(TF_OperationDescription * desc,const char * attr_name,const TF_DataType * values,int num_values)884 void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
885                         const TF_DataType* values, int num_values) {
886   desc->node_builder.Attr(
887       attr_name, ArraySlice<const DataType>(
888                      reinterpret_cast<const DataType*>(values), num_values));
889 }
890 
TF_SetAttrPlaceholder(TF_OperationDescription * desc,const char * attr_name,const char * placeholder)891 void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name,
892                            const char* placeholder) {
893   tensorflow::AttrValue attr_value;
894   attr_value.set_placeholder(placeholder);
895   desc->node_builder.Attr(attr_name, attr_value);
896 }
897 
TF_SetAttrFuncName(TF_OperationDescription * desc,const char * attr_name,const char * value,size_t length)898 void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
899                         const char* value, size_t length) {
900   tensorflow::NameAttrList func_name;
901   func_name.set_name(string(value, value + length));
902   desc->node_builder.Attr(attr_name, func_name);
903 }
904 
TF_SetAttrShape(TF_OperationDescription * desc,const char * attr_name,const int64_t * dims,int num_dims)905 void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
906                      const int64_t* dims, int num_dims) {
907   PartialTensorShape shape;
908   if (num_dims >= 0) {
909     shape = PartialTensorShape(
910         ArraySlice<int64_t>(reinterpret_cast<const int64_t*>(dims), num_dims));
911   }
912   desc->node_builder.Attr(attr_name, shape);
913 }
914 
TF_SetAttrShapeList(TF_OperationDescription * desc,const char * attr_name,const int64_t * const * dims,const int * num_dims,int num_shapes)915 void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
916                          const int64_t* const* dims, const int* num_dims,
917                          int num_shapes) {
918   std::vector<PartialTensorShape> shapes;
919   shapes.reserve(num_shapes);
920   for (int i = 0; i < num_shapes; ++i) {
921     if (num_dims[i] < 0) {
922       shapes.emplace_back();
923     } else {
924       shapes.emplace_back(ArraySlice<int64_t>(
925           reinterpret_cast<const int64_t*>(dims[i]), num_dims[i]));
926     }
927   }
928   desc->node_builder.Attr(attr_name, shapes);
929 }
930 
TF_SetAttrTensorShapeProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)931 void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
932                                 const char* attr_name, const void* proto,
933                                 size_t proto_len, TF_Status* status) {
934   // shape.ParseFromArray takes an int as length, this function takes size_t,
935   // make sure there is no information loss.
936   if (proto_len > std::numeric_limits<int>::max()) {
937     status->status = InvalidArgument(
938         "proto_len (", proto_len,
939         " bytes) is too large to be parsed by the protocol buffer library");
940     return;
941   }
942   TensorShapeProto shape;
943   if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
944     desc->node_builder.Attr(attr_name, shape);
945     status->status = ::tensorflow::OkStatus();
946   } else {
947     status->status = InvalidArgument("Unparseable TensorShapeProto");
948   }
949 }
950 
TF_SetAttrTensorShapeProtoList(TF_OperationDescription * desc,const char * attr_name,const void * const * protos,const size_t * proto_lens,int num_shapes,TF_Status * status)951 void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
952                                     const char* attr_name,
953                                     const void* const* protos,
954                                     const size_t* proto_lens, int num_shapes,
955                                     TF_Status* status) {
956   std::vector<TensorShapeProto> shapes;
957   shapes.resize(num_shapes);
958   for (int i = 0; i < num_shapes; ++i) {
959     if (proto_lens[i] > std::numeric_limits<int>::max()) {
960       status->status = InvalidArgument(
961           "length of element ", i, " in the list (", proto_lens[i],
962           " bytes) is too large to be parsed by the protocol buffer library");
963       return;
964     }
965     if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) {
966       status->status =
967           InvalidArgument("Unparseable TensorShapeProto at index ", i);
968       return;
969     }
970   }
971   desc->node_builder.Attr(attr_name, shapes);
972   status->status = ::tensorflow::OkStatus();
973 }
974 
TF_SetAttrTensor(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * value,TF_Status * status)975 void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
976                       TF_Tensor* value, TF_Status* status) {
977   Tensor t;
978   status->status = TF_TensorToTensor(value, &t);
979   if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
980 }
981 
TF_SetAttrTensorList(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * const * values,int num_values,TF_Status * status)982 void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
983                           TF_Tensor* const* values, int num_values,
984                           TF_Status* status) {
985   status->status = ::tensorflow::OkStatus();
986   std::vector<Tensor> t;
987   t.reserve(num_values);
988 
989   for (int i = 0; i < num_values && status->status.ok(); ++i) {
990     Tensor v;
991     status->status = TF_TensorToTensor(values[i], &v);
992     t.emplace_back(v);
993   }
994 
995   if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
996 }
997 
TF_SetAttrValueProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)998 void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
999                           const void* proto, size_t proto_len,
1000                           TF_Status* status) {
1001   tensorflow::AttrValue attr_value;
1002   if (!attr_value.ParseFromArray(proto, proto_len)) {
1003     status->status = InvalidArgument("Unparseable AttrValue proto");
1004     return;
1005   }
1006 
1007   if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1008     if (attr_value.value_case() != tensorflow::AttrValue::kList &&
1009         attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
1010       status->status =
1011           InvalidArgument("Expected \"list\" field for \"",
1012                           tensorflow::kColocationAttrName, "\" attribute");
1013       return;
1014     }
1015     desc->colocation_constraints.clear();
1016     for (const string& location : attr_value.list().s()) {
1017       desc->colocation_constraints.insert(location);
1018     }
1019   } else {
1020     desc->node_builder.Attr(attr_name, std::move(attr_value));
1021   }
1022 
1023   status->status = ::tensorflow::OkStatus();
1024 }
1025 
// Finalizes `desc` into a node of the graph; the caller must hold
// desc->graph->mu. Always consumes and deletes `desc`. Returns the new
// operation, or nullptr on failure with the error recorded in `status`.
TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
                                       TF_Status* status)
    TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
  Node* ret = nullptr;

  if (desc->graph->name_map.count(desc->node_builder.node_name())) {
    status->status = InvalidArgument("Duplicate node name in graph: '",
                                     desc->node_builder.node_name(), "'");
  } else {
    if (!desc->colocation_constraints.empty()) {
      // Attach the colocation groups accumulated by TF_ColocateWith /
      // TF_SetAttrStringList as the colocation attr.
      desc->node_builder.Attr(
          tensorflow::kColocationAttrName,
          std::vector<string>(desc->colocation_constraints.begin(),
                              desc->colocation_constraints.end()));
    }
    status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret,
                                                 /*consume=*/true);

    if (status->status.ok()) {
      // Run shape inference function for newly added node.
      status->status = desc->graph->refiner.AddNode(ret);
    }
    if (status->status.ok()) {
      // Add the node to the name-to-node mapping.
      desc->graph->name_map[ret->name()] = ret;
    } else if (ret != nullptr) {
      // Shape inference failed after the node was created: remove it again
      // so the graph stays consistent with name_map and the refiner.
      desc->graph->graph.RemoveNode(ret);
      ret = nullptr;
    }
  }

  // The description is consumed regardless of success or failure.
  delete desc;

  return ToOperation(ret);
}
1061 
TF_FinishOperation(TF_OperationDescription * desc,TF_Status * status)1062 TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
1063                                  TF_Status* status) {
1064   mutex_lock l(desc->graph->mu);
1065   return TF_FinishOperationLocked(desc, status);
1066 }
1067 
1068 // TF_Operation functions
1069 // ----------------------------------------------------------
1070 
TF_OperationName(TF_Operation * oper)1071 const char* TF_OperationName(TF_Operation* oper) {
1072   return oper->node.name().c_str();
1073 }
1074 
TF_OperationOpType(TF_Operation * oper)1075 const char* TF_OperationOpType(TF_Operation* oper) {
1076   return oper->node.type_string().c_str();
1077 }
1078 
TF_OperationDevice(TF_Operation * oper)1079 const char* TF_OperationDevice(TF_Operation* oper) {
1080   return oper->node.requested_device().c_str();
1081 }
1082 
TF_OperationNumOutputs(TF_Operation * oper)1083 int TF_OperationNumOutputs(TF_Operation* oper) {
1084   return oper->node.num_outputs();
1085 }
1086 
TF_OperationOutputType(TF_Output oper_out)1087 TF_DataType TF_OperationOutputType(TF_Output oper_out) {
1088   return static_cast<TF_DataType>(
1089       oper_out.oper->node.output_type(oper_out.index));
1090 }
1091 
TF_OperationOutputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1092 int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
1093                                  TF_Status* status) {
1094   NameRangeMap name_ranges;
1095   status->status =
1096       NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
1097   if (!status->status.ok()) return -1;
1098   auto iter = name_ranges.find(arg_name);
1099   if (iter == name_ranges.end()) {
1100     status->status = InvalidArgument("Output arg '", arg_name, "' not found");
1101     return -1;
1102   }
1103   return iter->second.second - iter->second.first;
1104 }
1105 
TF_OperationNumInputs(TF_Operation * oper)1106 int TF_OperationNumInputs(TF_Operation* oper) {
1107   return oper->node.num_inputs();
1108 }
1109 
TF_OperationInputType(TF_Input oper_in)1110 TF_DataType TF_OperationInputType(TF_Input oper_in) {
1111   return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
1112 }
1113 
TF_OperationInputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1114 int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
1115                                 TF_Status* status) {
1116   NameRangeMap name_ranges;
1117   status->status =
1118       NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
1119   if (!status->status.ok()) return -1;
1120   auto iter = name_ranges.find(arg_name);
1121   if (iter == name_ranges.end()) {
1122     status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1123     return -1;
1124   }
1125   return iter->second.second - iter->second.first;
1126 }
1127 
TF_OperationInput(TF_Input oper_in)1128 TF_Output TF_OperationInput(TF_Input oper_in) {
1129   const tensorflow::Edge* edge;
1130   Status s = oper_in.oper->node.input_edge(oper_in.index, &edge);
1131   if (!s.ok()) {
1132     return {nullptr, -1};
1133   }
1134 
1135   return {ToOperation(edge->src()), edge->src_output()};
1136 }
1137 
TF_OperationAllInputs(TF_Operation * oper,TF_Output * inputs,int max_inputs)1138 void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs,
1139                            int max_inputs) {
1140   for (auto* edge : oper->node.in_edges()) {
1141     if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) {
1142       inputs[edge->dst_input()] = {ToOperation(edge->src()),
1143                                    edge->src_output()};
1144     }
1145   }
1146 }
1147 
TF_OperationOutputNumConsumers(TF_Output oper_out)1148 int TF_OperationOutputNumConsumers(TF_Output oper_out) {
1149   int count = 0;
1150   for (const auto* edge : oper_out.oper->node.out_edges()) {
1151     if (edge->src_output() == oper_out.index) {
1152       ++count;
1153     }
1154   }
1155   return count;
1156 }
1157 
TF_OperationOutputConsumers(TF_Output oper_out,TF_Input * consumers,int max_consumers)1158 int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
1159                                 int max_consumers) {
1160   int count = 0;
1161   for (const auto* edge : oper_out.oper->node.out_edges()) {
1162     if (edge->src_output() == oper_out.index) {
1163       if (count < max_consumers) {
1164         consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
1165       }
1166       ++count;
1167     }
1168   }
1169   return count;
1170 }
1171 
TF_OperationNumControlInputs(TF_Operation * oper)1172 int TF_OperationNumControlInputs(TF_Operation* oper) {
1173   int count = 0;
1174   for (const auto* edge : oper->node.in_edges()) {
1175     if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1176       ++count;
1177     }
1178   }
1179   return count;
1180 }
1181 
TF_OperationGetControlInputs(TF_Operation * oper,TF_Operation ** control_inputs,int max_control_inputs)1182 int TF_OperationGetControlInputs(TF_Operation* oper,
1183                                  TF_Operation** control_inputs,
1184                                  int max_control_inputs) {
1185   int count = 0;
1186   for (const auto* edge : oper->node.in_edges()) {
1187     if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1188       if (count < max_control_inputs) {
1189         control_inputs[count] = ToOperation(edge->src());
1190       }
1191       ++count;
1192     }
1193   }
1194   return count;
1195 }
1196 
TF_OperationNumControlOutputs(TF_Operation * oper)1197 int TF_OperationNumControlOutputs(TF_Operation* oper) {
1198   int count = 0;
1199   for (const auto* edge : oper->node.out_edges()) {
1200     if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1201       ++count;
1202     }
1203   }
1204   return count;
1205 }
1206 
TF_OperationGetControlOutputs(TF_Operation * oper,TF_Operation ** control_outputs,int max_control_outputs)1207 int TF_OperationGetControlOutputs(TF_Operation* oper,
1208                                   TF_Operation** control_outputs,
1209                                   int max_control_outputs) {
1210   int count = 0;
1211   for (const auto* edge : oper->node.out_edges()) {
1212     if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1213       if (count < max_control_outputs) {
1214         control_outputs[count] = ToOperation(edge->dst());
1215       }
1216       ++count;
1217     }
1218   }
1219   return count;
1220 }
1221 
TF_OperationGetAttrMetadata(TF_Operation * oper,const char * attr_name,TF_Status * status)1222 TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1223                                             const char* attr_name,
1224                                             TF_Status* status) {
1225   TF_AttrMetadata metadata;
1226   const auto* attr = GetAttrValue(oper, attr_name, status);
1227   if (!status->status.ok()) return metadata;
1228   switch (attr->value_case()) {
1229 #define SINGLE_CASE(kK, attr_type, size_expr) \
1230   case tensorflow::AttrValue::kK:             \
1231     metadata.is_list = 0;                     \
1232     metadata.list_size = -1;                  \
1233     metadata.type = attr_type;                \
1234     metadata.total_size = size_expr;          \
1235     break;
1236 
1237     SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
1238     SINGLE_CASE(kI, TF_ATTR_INT, -1);
1239     SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
1240     SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
1241     SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
1242     SINGLE_CASE(kShape, TF_ATTR_SHAPE,
1243                 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
1244     SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
1245 #undef SINGLE_CASE
1246 
1247     case tensorflow::AttrValue::kList:
1248       metadata.is_list = 1;
1249       metadata.list_size = 0;
1250       metadata.total_size = -1;
1251 #define LIST_CASE(field, attr_type, ...)              \
1252   if (attr->list().field##_size() > 0) {              \
1253     metadata.type = attr_type;                        \
1254     metadata.list_size = attr->list().field##_size(); \
1255     __VA_ARGS__;                                      \
1256     break;                                            \
1257   }
1258 
1259       LIST_CASE(
1260           s, TF_ATTR_STRING, metadata.total_size = 0;
1261           for (int i = 0; i < attr->list().s_size();
1262                ++i) { metadata.total_size += attr->list().s(i).size(); });
1263       LIST_CASE(i, TF_ATTR_INT);
1264       LIST_CASE(f, TF_ATTR_FLOAT);
1265       LIST_CASE(b, TF_ATTR_BOOL);
1266       LIST_CASE(type, TF_ATTR_TYPE);
1267       LIST_CASE(
1268           shape, TF_ATTR_SHAPE, metadata.total_size = 0;
1269           for (int i = 0; i < attr->list().shape_size(); ++i) {
1270             const auto& s = attr->list().shape(i);
1271             metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
1272           });
1273       LIST_CASE(tensor, TF_ATTR_TENSOR);
1274       LIST_CASE(tensor, TF_ATTR_FUNC);
1275 #undef LIST_CASE
1276       // All lists empty, determine the type from the OpDef.
1277       if (metadata.list_size == 0) {
1278         for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
1279           const auto& a = oper->node.op_def().attr(i);
1280           if (a.name() != attr_name) continue;
1281           const string& typestr = a.type();
1282           if (typestr == "list(string)") {
1283             metadata.type = TF_ATTR_STRING;
1284           } else if (typestr == "list(int)") {
1285             metadata.type = TF_ATTR_INT;
1286           } else if (typestr == "list(float)") {
1287             metadata.type = TF_ATTR_FLOAT;
1288           } else if (typestr == "list(bool)") {
1289             metadata.type = TF_ATTR_BOOL;
1290           } else if (typestr == "list(type)") {
1291             metadata.type = TF_ATTR_TYPE;
1292           } else if (typestr == "list(shape)") {
1293             metadata.type = TF_ATTR_SHAPE;
1294           } else if (typestr == "list(tensor)") {
1295             metadata.type = TF_ATTR_TENSOR;
1296           } else if (typestr == "list(func)") {
1297             metadata.type = TF_ATTR_FUNC;
1298           } else {
1299             status->status = InvalidArgument(
1300                 "Attribute '", attr_name,
1301                 "' has an empty value of an unrecognized type '", typestr, "'");
1302             return metadata;
1303           }
1304         }
1305       }
1306       break;
1307 
1308     case tensorflow::AttrValue::kPlaceholder:
1309       metadata.is_list = 0;
1310       metadata.list_size = -1;
1311       metadata.type = TF_ATTR_PLACEHOLDER;
1312       metadata.total_size = -1;
1313       break;
1314 
1315     case tensorflow::AttrValue::kFunc:
1316       metadata.is_list = 0;
1317       metadata.list_size = -1;
1318       metadata.type = TF_ATTR_FUNC;
1319       metadata.total_size = -1;
1320       break;
1321 
1322     case tensorflow::AttrValue::VALUE_NOT_SET:
1323       status->status =
1324           InvalidArgument("Attribute '", attr_name, "' has no value set");
1325       break;
1326   }
1327   return metadata;
1328 }
1329 
TF_OperationGetAttrString(TF_Operation * oper,const char * attr_name,void * value,size_t max_length,TF_Status * status)1330 void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
1331                                void* value, size_t max_length,
1332                                TF_Status* status) {
1333   const auto* attr = GetAttrValue(oper, attr_name, status);
1334   if (!status->status.ok()) return;
1335   if (attr->value_case() != tensorflow::AttrValue::kS) {
1336     status->status =
1337         InvalidArgument("Attribute '", attr_name, "' is not a string");
1338     return;
1339   }
1340   if (max_length <= 0) {
1341     return;
1342   }
1343   const auto& s = attr->s();
1344   std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length));
1345 }
1346 
// Copies the elements of a list(string) attr into caller-provided arrays.
// values[i]/lengths[i] receive a pointer into `storage` and the byte length
// of element i; at most `max_values` elements are returned, packed into
// `storage` (storage_size bytes). Strings are not NUL-terminated.
// NOTE: values[i] and lengths[i] are written before the capacity check, so
// when the out-of-storage error fires, earlier entries are valid but entry
// i points at memory that was never filled.
void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
                                   void** values, size_t* lengths,
                                   int max_values, void* storage,
                                   size_t storage_size, TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (!status->status.ok()) return;
  if (attr->value_case() != tensorflow::AttrValue::kList) {
    status->status =
        InvalidArgument("Value for '", attr_name, "' is not a list");
    return;
  }
  const auto len = std::min(max_values, attr->list().s_size());
  char* p = static_cast<char*>(storage);
  for (int i = 0; i < len; ++i) {
    const string& s = attr->list().s(i);
    values[i] = p;
    lengths[i] = s.size();
    // Fail before copying if this element would run past `storage`.
    if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
      status->status = InvalidArgument(
          "Not enough storage to hold the requested list of strings");
      return;
    }
    memcpy(values[i], s.data(), s.size());
    p += s.size();
  }
}
1373 
// Expands to a pair of C API attr getters for one scalar value type:
//   func       — reads a single attr of `cpp_type` and narrows it to
//                `c_type` for the caller.
//   func##List — copies up to `max_values` elements of the repeated
//                `list_field` field of a list attr into `values`.
// (Comments cannot appear inside the backslash-continued macro body.)
#define DEFINE_GETATTR(func, c_type, cpp_type, list_field)                   \
  void func(TF_Operation* oper, const char* attr_name, c_type* value,        \
            TF_Status* status) {                                             \
    cpp_type v;                                                              \
    status->status =                                                         \
        tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v);          \
    if (!status->status.ok()) return;                                        \
    *value = static_cast<c_type>(v);                                         \
  }                                                                          \
  void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
                  int max_values, TF_Status* status) {                       \
    const auto* attr = GetAttrValue(oper, attr_name, status);                \
    if (!status->status.ok()) return;                                        \
    if (attr->value_case() != tensorflow::AttrValue::kList) {                \
      status->status =                                                       \
          InvalidArgument("Value for '", attr_name, "' is not a list.");     \
      return;                                                                \
    }                                                                        \
    const auto len = std::min(max_values, attr->list().list_field##_size()); \
    for (int i = 0; i < len; ++i) {                                          \
      values[i] = static_cast<c_type>(attr->list().list_field(i));           \
    }                                                                        \
  }
// Instantiate the getters for int, float, bool, and dtype attrs.
DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, int64_t, i);
DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
#undef DEFINE_GETATTR
1402 
TF_OperationGetAttrShape(TF_Operation * oper,const char * attr_name,int64_t * value,int num_dims,TF_Status * status)1403 void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
1404                               int64_t* value, int num_dims, TF_Status* status) {
1405   PartialTensorShape shape;
1406   status->status =
1407       tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
1408   if (!status->status.ok()) return;
1409   auto len = std::min(shape.dims(), num_dims);
1410   for (int i = 0; i < len; ++i) {
1411     value[i] = shape.dim_size(i);
1412   }
1413 }
1414 
// Fetches a list-of-shapes attribute. Dimension values for all returned
// shapes are packed into the caller-provided `storage` array (capacity
// `storage_size` int64s); `dims[i]` points into that storage and
// `num_dims[i]` holds the rank of shape i (-1 for unknown rank, in which
// case `dims[i]` has no valid dimensions to read).
void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
                                  int64_t** dims, int* num_dims, int num_shapes,
                                  int64_t* storage, int storage_size,
                                  TF_Status* status) {
  std::vector<PartialTensorShape> shapes;
  status->status =
      tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
  if (!status->status.ok()) return;
  // Copy out at most `num_shapes` shapes.
  auto len = std::min(static_cast<int>(shapes.size()), num_shapes);
  int64_t* p = storage;
  int storage_left = storage_size;
  for (int i = 0; i < len; ++i) {
    // shapes[i].dims() == -1 for shapes with an unknown rank.
    int64_t n = shapes[i].dims();
    num_dims[i] = n;
    dims[i] = p;
    if (n < 0) {
      // Unknown rank: nothing to copy into storage for this shape.
      continue;
    }
    if (storage_left < n) {
      status->status = InvalidArgument(
          "Not enough storage to hold the requested list of shapes");
      return;
    }
    storage_left -= n;
    for (int j = 0; j < n; ++j, ++p) {
      *p = shapes[i].dim_size(j);
    }
  }
}
1445 
TF_OperationGetAttrTensorShapeProto(TF_Operation * oper,const char * attr_name,TF_Buffer * value,TF_Status * status)1446 void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
1447                                          const char* attr_name,
1448                                          TF_Buffer* value, TF_Status* status) {
1449   const auto* attr = GetAttrValue(oper, attr_name, status);
1450   if (!status->status.ok()) return;
1451   if (attr->value_case() != tensorflow::AttrValue::kShape) {
1452     status->status =
1453         InvalidArgument("Value for '", attr_name, "' is not a shape.");
1454     return;
1455   }
1456   status->status = MessageToBuffer(attr->shape(), value);
1457 }
1458 
TF_OperationGetAttrTensorShapeProtoList(TF_Operation * oper,const char * attr_name,TF_Buffer ** values,int max_values,TF_Status * status)1459 void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
1460                                              const char* attr_name,
1461                                              TF_Buffer** values, int max_values,
1462                                              TF_Status* status) {
1463   const auto* attr = GetAttrValue(oper, attr_name, status);
1464   if (!status->status.ok()) return;
1465   if (attr->value_case() != tensorflow::AttrValue::kList) {
1466     status->status =
1467         InvalidArgument("Value for '", attr_name, "' is not a list");
1468     return;
1469   }
1470   const auto len = std::min(max_values, attr->list().shape_size());
1471   for (int i = 0; i < len; ++i) {
1472     values[i] = TF_NewBuffer();
1473     status->status = MessageToBuffer(attr->list().shape(i), values[i]);
1474     if (!status->status.ok()) {
1475       // Delete everything allocated to far, the operation has failed.
1476       for (int j = 0; j <= i; ++j) {
1477         TF_DeleteBuffer(values[j]);
1478       }
1479       return;
1480     }
1481   }
1482 }
1483 
// Fetches a tensor-valued attribute. On success `*value` holds a newly
// allocated TF_Tensor (caller takes ownership); on failure `*value` is null
// and `status` holds the error.
void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
                               TF_Tensor** value, TF_Status* status) {
  *value = nullptr;
  Tensor t;
  status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
  if (!status->status.ok()) return;
  *value = TF_TensorFromTensor(t, &status->status);
}
1492 
TF_OperationGetAttrTensorList(TF_Operation * oper,const char * attr_name,TF_Tensor ** values,int max_values,TF_Status * status)1493 void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
1494                                    TF_Tensor** values, int max_values,
1495                                    TF_Status* status) {
1496   std::vector<Tensor> ts;
1497   status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
1498   if (!status->status.ok()) return;
1499   const auto len = std::min(max_values, static_cast<int>(ts.size()));
1500   for (int i = 0; i < len; ++i) {
1501     values[i] = TF_TensorFromTensor(ts[i], &status->status);
1502   }
1503 }
1504 
// Serializes the raw AttrValue proto for `attr_name` into
// `output_attr_value`, regardless of the attribute's value type.
void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
                                   TF_Buffer* output_attr_value,
                                   TF_Status* status) {
  const auto* attr = GetAttrValue(oper, attr_name, status);
  if (!status->status.ok()) return;
  status->status = MessageToBuffer(*attr, output_attr_value);
}
1512 
TF_OperationGetNumAttrs(TF_Operation * oper)1513 int TF_OperationGetNumAttrs(TF_Operation* oper) {
1514   return oper->node.attrs().size();
1515 }
1516 
TF_OperationGetAttrNameLength(TF_Operation * oper,int i)1517 int TF_OperationGetAttrNameLength(TF_Operation* oper, int i) {
1518   auto attrs = oper->node.attrs();
1519   int count = 0;
1520   AttrValueMap::const_iterator it;
1521   for (it = attrs.begin(); it != attrs.end(); it++) {
1522     if (count == i) {
1523       return it->first.length();
1524     }
1525     count++;
1526   }
1527   return -1;
1528 }
1529 
TF_OperationGetAttrName(TF_Operation * oper,int i,char * output,TF_Status * status)1530 void TF_OperationGetAttrName(TF_Operation* oper, int i, char* output,
1531                              TF_Status* status) {
1532   auto attrs = oper->node.attrs();
1533   int count = 0;
1534   AttrValueMap::const_iterator it;
1535   for (it = attrs.begin(); it != attrs.end(); it++) {
1536     if (count == i) {
1537       strncpy(output, it->first.c_str(), it->first.length());
1538       status->status = ::tensorflow::OkStatus();
1539       return;
1540     }
1541     count++;
1542   }
1543   status->status = OutOfRange("Operation only has ", count,
1544                               " attributes, can't get the ", i, "th");
1545 }
1546 
// Serializes the operation's NodeDef proto into `output_node_def`.
void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
                           TF_Status* status) {
  status->status = MessageToBuffer(oper->node.def(), output_node_def);
}
1551 
1552 // TF_Graph functions ---------------------------------------------------------
1553 
// Constructs an empty graph backed by the global op registry, with a shape
// refiner pinned to the graph's producer version. The graph is not attached
// to a parent (parent/parent_inputs are only set for while-loop sub-graphs).
TF_Graph::TF_Graph()
    : graph(tensorflow::OpRegistry::Global()),
      refiner(graph.versions().producer(), graph.op_registry()),
      delete_requested(false),
      parent(nullptr),
      parent_inputs(nullptr) {
  // Tell the shape refiner to also run shape inference on functions.
  refiner.set_function_library_for_shape_inference(&graph.flib_def());
}
1563 
TF_NewGraph()1564 TF_Graph* TF_NewGraph() { return new TF_Graph; }
1565 
TF_DeleteGraph(TF_Graph * g)1566 void TF_DeleteGraph(TF_Graph* g) {
1567   if (g == nullptr) return;
1568   g->mu.lock();
1569   g->delete_requested = true;
1570   const bool del = g->sessions.empty();
1571   g->mu.unlock();
1572   if (del) delete g;
1573 }
1574 
TF_GraphOperationByName(TF_Graph * graph,const char * oper_name)1575 TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
1576   mutex_lock l(graph->mu);
1577   auto iter = graph->name_map.find(oper_name);
1578   if (iter == graph->name_map.end()) {
1579     return nullptr;
1580   } else {
1581     return ToOperation(iter->second);
1582   }
1583 }
1584 
TF_GraphNextOperation(TF_Graph * graph,size_t * pos)1585 TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
1586   if (*pos == 0) {
1587     // Advance past the first sentinel nodes in every graph (the source & sink).
1588     *pos += 2;
1589   } else {
1590     // Advance to the next node.
1591     *pos += 1;
1592   }
1593 
1594   mutex_lock l(graph->mu);
1595   while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) {
1596     Node* node = graph->graph.FindNodeId(*pos);
1597     // FindNodeId() returns nullptr for nodes that have been deleted.
1598     // We aren't currently allowing nodes to be deleted, but it is safer
1599     // to still check.
1600     if (node != nullptr) return ToOperation(node);
1601     *pos += 1;
1602   }
1603 
1604   // No more nodes.
1605   return nullptr;
1606 }
1607 
TF_GraphToGraphDef(TF_Graph * graph,TF_Buffer * output_graph_def,TF_Status * status)1608 void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
1609                         TF_Status* status) {
1610   GraphDef def;
1611   {
1612     mutex_lock l(graph->mu);
1613     graph->graph.ToGraphDef(&def);
1614   }
1615   status->status = MessageToBuffer(def, output_graph_def);
1616 }
1617 
TF_GraphGetOpDef(TF_Graph * graph,const char * op_name,TF_Buffer * output_op_def,TF_Status * status)1618 void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
1619                       TF_Buffer* output_op_def, TF_Status* status) {
1620   const OpDef* op_def;
1621   {
1622     mutex_lock l(graph->mu);
1623     status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
1624     if (!status->status.ok()) return;
1625   }
1626   status->status = MessageToBuffer(*op_def, output_op_def);
1627 }
1628 
TF_GraphVersions(TF_Graph * graph,TF_Buffer * output_version_def,TF_Status * status)1629 void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
1630                       TF_Status* status) {
1631   VersionDef versions;
1632   {
1633     mutex_lock l(graph->mu);
1634     versions = graph->graph.versions();
1635   }
1636   status->status = MessageToBuffer(versions, output_version_def);
1637 }
1638 
// Allocates a default-initialized set of GraphDef-import options; release
// with TF_DeleteImportGraphDefOptions().
TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
  return new TF_ImportGraphDefOptions;
}
// Releases options created with TF_NewImportGraphDefOptions().
void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
  delete opts;
}
// Sets the name prefix prepended to every node imported from the GraphDef.
void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
                                       const char* prefix) {
  opts->opts.prefix = prefix;
}
// Sets the device assigned to imported nodes that have no explicit device.
void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts,
                                              const char* device) {
  opts->opts.default_device = device;
}
1653 
// If non-zero, imported node names are uniquified on collision with
// existing names instead of failing the import.
void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
                                              unsigned char uniquify_names) {
  opts->opts.uniquify_names = uniquify_names;
}
1658 
// If non-zero, the import prefix itself is uniquified when it collides with
// an existing name in the destination graph.
void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
                                               unsigned char uniquify_prefix) {
  opts->opts.uniquify_prefix = uniquify_prefix;
}
1663 
// Maps output `src_name:src_index` of the GraphDef being imported to the
// existing tensor `dst` in the destination graph.
void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
                                             const char* src_name,
                                             int src_index, TF_Output dst) {
  // Copy src_name into tensor_id_data so the TensorId key (which references
  // the stored string rather than owning it) stays valid for the lifetime
  // of the options.
  opts->tensor_id_data.push_back(src_name);
  const string& src_name_str = opts->tensor_id_data.back();
  // We don't need to store dst's name in tensor_id_data, since `dst` must
  // outlive the ImportGraphDef call.
  opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
}
1673 
// Redirects control dependencies on node `src_name` in the imported GraphDef
// to the existing operation `dst` (kControlSlot marks a control edge rather
// than a data output).
void TF_ImportGraphDefOptionsRemapControlDependency(
    TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
  opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
      TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
}
1679 
// Makes every node imported from the GraphDef depend on `oper` via a
// control edge.
extern void TF_ImportGraphDefOptionsAddControlDependency(
    TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
  opts->opts.control_dependencies.push_back(oper->node.name());
}
1684 
// Requests that output `oper_name:index` of the imported graph be returned
// via TF_ImportGraphDefResultsReturnOutputs().
void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
                                             const char* oper_name, int index) {
  // Copy oper_name into tensor_id_data so the TensorId (which references
  // the stored string) stays valid for the lifetime of the options.
  opts->tensor_id_data.push_back(oper_name);
  const string& oper_name_str = opts->tensor_id_data.back();
  opts->opts.return_tensors.emplace_back(oper_name_str, index);
}
1691 
// Returns how many outputs have been requested with
// TF_ImportGraphDefOptionsAddReturnOutput().
int TF_ImportGraphDefOptionsNumReturnOutputs(
    const TF_ImportGraphDefOptions* opts) {
  return opts->opts.return_tensors.size();
}
1696 
// Requests that the imported node named `oper_name` be returned via
// TF_ImportGraphDefResultsReturnOperations().
void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
                                                const char* oper_name) {
  opts->opts.return_nodes.push_back(oper_name);
}
1701 
// Returns how many operations have been requested with
// TF_ImportGraphDefOptionsAddReturnOperation().
int TF_ImportGraphDefOptionsNumReturnOperations(
    const TF_ImportGraphDefOptions* opts) {
  return opts->opts.return_nodes.size();
}
1706 
// Exposes the outputs requested via AddReturnOutput(). The returned array is
// owned by `results` and is valid only until the results are deleted.
void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
                                           int* num_outputs,
                                           TF_Output** outputs) {
  *num_outputs = results->return_tensors.size();
  *outputs = results->return_tensors.data();
}
1713 
// Exposes the operations requested via AddReturnOperation(). The returned
// array is owned by `results` and is valid only until they are deleted.
void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
                                              int* num_opers,
                                              TF_Operation*** opers) {
  *num_opers = results->return_nodes.size();
  *opers = results->return_nodes.data();
}
1720 
// Exposes input mappings that named tensors absent from the GraphDef (and so
// were never used). The parallel name/index arrays are owned by `results`.
void TF_ImportGraphDefResultsMissingUnusedInputMappings(
    TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
    const char*** src_names, int** src_indexes) {
  *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
  *src_names = results->missing_unused_key_names.data();
  *src_indexes = results->missing_unused_key_indexes.data();
}
1728 
// Releases results returned by TF_GraphImportGraphDefWithResults(), along
// with every array previously handed out from them.
void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
  delete results;
}
1732 
GraphImportGraphDefLocked(TF_Graph * graph,const GraphDef & def,const TF_ImportGraphDefOptions * opts,TF_ImportGraphDefResults * tf_results,TF_Status * status)1733 static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
1734                                       const TF_ImportGraphDefOptions* opts,
1735                                       TF_ImportGraphDefResults* tf_results,
1736                                       TF_Status* status)
1737     TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
1738   const int last_node_id = graph->graph.num_node_ids();
1739   tensorflow::ImportGraphDefResults results;
1740   status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
1741                                               &graph->refiner, &results);
1742   if (!status->status.ok()) return;
1743 
1744   // Add new nodes to name_map
1745   for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
1746     auto* node = graph->graph.FindNodeId(i);
1747     if (node != nullptr) graph->name_map[node->name()] = node;
1748   }
1749 
1750   // Populate return_tensors
1751   DCHECK(tf_results->return_tensors.empty());
1752   tf_results->return_tensors.resize(results.return_tensors.size());
1753   for (int i = 0; i < results.return_tensors.size(); ++i) {
1754     tf_results->return_tensors[i].oper =
1755         ToOperation(results.return_tensors[i].first);
1756     tf_results->return_tensors[i].index = results.return_tensors[i].second;
1757   }
1758 
1759   // Populate return_nodes
1760   DCHECK(tf_results->return_nodes.empty());
1761   tf_results->return_nodes.resize(results.return_nodes.size());
1762   for (int i = 0; i < results.return_nodes.size(); ++i) {
1763     tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
1764   }
1765 
1766   // Populate missing unused map keys
1767   DCHECK(tf_results->missing_unused_key_names.empty());
1768   DCHECK(tf_results->missing_unused_key_indexes.empty());
1769   DCHECK(tf_results->missing_unused_key_names_data.empty());
1770 
1771   size_t size = results.missing_unused_input_map_keys.size();
1772   tf_results->missing_unused_key_names.resize(size);
1773   tf_results->missing_unused_key_indexes.resize(size);
1774 
1775   for (int i = 0; i < size; ++i) {
1776     TensorId id = results.missing_unused_input_map_keys[i];
1777     tf_results->missing_unused_key_names_data.emplace_back(id.first);
1778     tf_results->missing_unused_key_names[i] =
1779         tf_results->missing_unused_key_names_data.back().c_str();
1780     tf_results->missing_unused_key_indexes[i] = id.second;
1781   }
1782 }
1783 
TF_GraphImportGraphDefWithResults(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Status * status)1784 TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
1785     TF_Graph* graph, const TF_Buffer* graph_def,
1786     const TF_ImportGraphDefOptions* options, TF_Status* status) {
1787   GraphDef def;
1788   if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
1789                                        graph_def->length)) {
1790     status->status = InvalidArgument("Invalid GraphDef");
1791     return nullptr;
1792   }
1793   auto results = new TF_ImportGraphDefResults();
1794   mutex_lock l(graph->mu);
1795   GraphImportGraphDefLocked(graph, def, options, results, status);
1796   if (!status->status.ok()) {
1797     delete results;
1798     return nullptr;
1799   }
1800   return results;
1801 }
1802 
TF_GraphImportGraphDefWithReturnOutputs(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Output * return_outputs,int num_return_outputs,TF_Status * status)1803 void TF_GraphImportGraphDefWithReturnOutputs(
1804     TF_Graph* graph, const TF_Buffer* graph_def,
1805     const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
1806     int num_return_outputs, TF_Status* status) {
1807   if (num_return_outputs != options->opts.return_tensors.size()) {
1808     status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
1809                                      options->opts.return_tensors.size(),
1810                                      ", got ", num_return_outputs);
1811     return;
1812   }
1813   if (num_return_outputs > 0 && return_outputs == nullptr) {
1814     status->status = InvalidArgument(
1815         "'return_outputs' must be preallocated to length ", num_return_outputs);
1816     return;
1817   }
1818   GraphDef def;
1819   if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
1820                                        graph_def->length)) {
1821     status->status = InvalidArgument("Invalid GraphDef");
1822     return;
1823   }
1824   TF_ImportGraphDefResults results;
1825   mutex_lock l(graph->mu);
1826   GraphImportGraphDefLocked(graph, def, options, &results, status);
1827   DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
1828   memcpy(return_outputs, results.return_tensors.data(),
1829          num_return_outputs * sizeof(TF_Output));
1830 }
1831 
// Convenience wrapper: imports `graph_def` and discards the results object.
// TF_DeleteImportGraphDefResults tolerates the nullptr returned on failure
// (deleting a null pointer is a no-op).
void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
                            const TF_ImportGraphDefOptions* options,
                            TF_Status* status) {
  TF_ImportGraphDefResults* results =
      TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
  TF_DeleteImportGraphDefResults(results);
}
1839 
1840 // While loop functions -------------------------------------------------------
1841 
1842 namespace {
1843 
1844 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1845 
1846 // Creates a placeholder representing an input to the cond or body graph.
1847 // TODO(skyewm): remove these from final graph
// Returns true on success; on failure, `status` holds the error and `*input`
// is left unmodified.
bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
                 TF_Output* input, TF_Status* status) {
  TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
  // The placeholder takes its dtype from the parent graph's tensor it
  // stands in for.
  TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
  // TODO(skyewm): set placeholder shape
  TF_Operation* oper = TF_FinishOperation(desc, status);
  if (!status->status.ok()) return false;
  // Placeholders have a single output: index 0.
  *input = {oper, 0};
  return true;
}
1858 
1859 // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
1860 // `src_inputs[i]` will have that input replaced with `dst_inputs[i]`.  `prefix`
1861 // will be prepended to copied node names. `control_deps` are nodes in
1862 // `dst_graph` that the copied `src_graph` nodes will have control dependencies
1863 // on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
1864 // in `dst_graph` will be returned. `return_nodes` must be non-null.
CopyGraph(Graph * src_graph,Graph * dst_graph,tensorflow::ShapeRefiner * dst_refiner,const TF_Output * src_inputs,const std::vector<tensorflow::Output> & dst_inputs,const string & prefix,const std::vector<tensorflow::Operation> & control_deps,const TF_Output * nodes_to_return,int nreturn_nodes,std::vector<tensorflow::Output> * return_nodes)1865 tensorflow::Status CopyGraph(Graph* src_graph, Graph* dst_graph,
1866                  tensorflow::ShapeRefiner* dst_refiner,
1867                  const TF_Output* src_inputs,
1868                  const std::vector<tensorflow::Output>& dst_inputs,
1869                  const string& prefix,
1870                  const std::vector<tensorflow::Operation>& control_deps,
1871                  const TF_Output* nodes_to_return, int nreturn_nodes,
1872                  std::vector<tensorflow::Output>* return_nodes) {
1873   DCHECK(return_nodes != nullptr);
1874   GraphDef gdef;
1875   src_graph->ToGraphDef(&gdef);
1876 
1877   tensorflow::ImportGraphDefOptions opts;
1878   opts.prefix = prefix;
1879 
1880   for (int i = 0; i < dst_inputs.size(); ++i) {
1881     opts.input_map[ToTensorId(src_inputs[i])] =
1882         TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
1883   }
1884   opts.skip_mapped_nodes = true;
1885 
1886   for (const tensorflow::Operation& op : control_deps) {
1887     opts.control_dependencies.push_back(op.node()->name());
1888   }
1889 
1890   for (int i = 0; i < nreturn_nodes; ++i) {
1891     opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
1892   }
1893 
1894   // TODO(skyewm): change to OutputTensor
1895   tensorflow::ImportGraphDefResults results;
1896   TF_RETURN_IF_ERROR(
1897       ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
1898 
1899   for (const auto& pair : results.return_tensors) {
1900     return_nodes->emplace_back(pair.first, pair.second);
1901   }
1902   return ::tensorflow::OkStatus();
1903 }
1904 
// Returns true iff `params` looks like it was produced by a successful
// TF_NewWhile() call: both sub-graphs exist, share the same parent graph and
// parent inputs, and every input/output array is present with ninputs > 0.
bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
  if (params.cond_graph == nullptr || params.body_graph == nullptr ||
      params.cond_graph->parent == nullptr ||
      params.cond_graph->parent != params.body_graph->parent ||
      params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
      params.ninputs <= 0 || params.cond_inputs == nullptr ||
      params.body_inputs == nullptr || params.body_outputs == nullptr) {
    s->status = InvalidArgument(
        "TF_WhileParams must be created by successful TF_NewWhile() call");
    return false;
  }
  return true;
}
1918 
// Returns true iff the caller has filled in every field that finishing the
// while loop requires: the condition output, all body outputs, and the name.
bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
  if (params.cond_output.oper == nullptr) {
    s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
    return false;
  }
  for (int i = 0; i < params.ninputs; ++i) {
    if (params.body_outputs[i].oper == nullptr) {
      s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
                                  "field isn't set");
      return false;
    }
  }
  if (params.name == nullptr) {
    s->status = InvalidArgument("TF_WhileParams `name` field is null");
    return false;
  }
  return true;
}
1937 
1938 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1939 
// Releases everything TF_NewWhile() allocated. Safe on partially-constructed
// params: TF_DeleteGraph ignores nullptr, and delete[] on nullptr is a no-op.
void FreeWhileResources(const TF_WhileParams* params) {
  TF_DeleteGraph(params->cond_graph);
  TF_DeleteGraph(params->body_graph);
  delete[] params->cond_inputs;
  delete[] params->body_inputs;
  delete[] params->body_outputs;
}
1947 
// Returns a zeroed TF_WhileParams, used as TF_NewWhile()'s error value.
TF_WhileParams EmptyWhileParams() {
  return {0,       nullptr, nullptr, {nullptr, 0},
          nullptr, nullptr, nullptr, nullptr};
}
1952 
1953 }  // namespace
1954 
// Begins constructing a while loop over `inputs` in graph `g`. Allocates two
// sub-graphs (cond and body) plus placeholder inputs for each; the caller
// fills in cond_output/body_outputs and finishes with TF_FinishWhile (or
// aborts with TF_AbortWhile), which takes ownership of the returned params.
TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
                           TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Creating while loops is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
  return EmptyWhileParams();
#else
  if (ninputs == 0) {
    status->status =
        InvalidArgument("TF_NewWhile() must be passed at least one input");
    return EmptyWhileParams();
  }

  // The sub-graphs remember their parent so TF_FinishWhile can splice the
  // loop back into it.
  TF_Graph* cond_graph = TF_NewGraph();
  TF_Graph* body_graph = TF_NewGraph();
  cond_graph->parent = g;
  cond_graph->parent_inputs = inputs;
  body_graph->parent = g;
  body_graph->parent_inputs = inputs;

  TF_Output* cond_inputs = new TF_Output[ninputs];
  TF_Output cond_output = {nullptr, -1};
  TF_Output* body_inputs = new TF_Output[ninputs];
  TF_Output* body_outputs = new TF_Output[ninputs];
  // Sentinel values {nullptr, -1} let ValidateInputWhileParams detect fields
  // the caller never set.
  for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
  const char* name = nullptr;

  for (int i = 0; i < ninputs; ++i) {
    // TODO(skyewm): prefix names with underscore (requires some plumbing)
    if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
                     &cond_inputs[i], status)) {
      break;
    }
    if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
                     &body_inputs[i], status)) {
      break;
    }
  }

  TF_WhileParams params = {ninputs,    cond_graph,  cond_inputs,  cond_output,
                           body_graph, body_inputs, body_outputs, name};

  // On any placeholder-creation failure, release everything allocated above.
  if (!status->status.ok()) {
    FreeWhileResources(&params);
    return EmptyWhileParams();
  }
  return params;
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
2006 
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
namespace {

// TODO(skyewm): make nodes in while loop unfetchable like in Python version
//
// Copies the caller-populated cond and body graphs out of `params` into the
// parent graph, then builds the actual while loop there with
// tensorflow::ops::BuildWhileLoop(). On success, writes one TF_Output per
// loop variable into `outputs`. Any error is reported through `status`.
void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
                          TF_Output* outputs) {
  if (!ValidateInputWhileParams(*params, status)) return;

  // The parent graph and the original loop inputs were stashed on the cond
  // graph by TF_NewWhile (cond_graph->parent / parent_inputs).
  TF_Graph* parent = params->cond_graph->parent;
  TF_Output* parent_inputs = params->cond_graph->parent_inputs;
  int num_loop_vars = params->ninputs;

  mutex_lock l(parent->mu);

  // 'cond_fn' copies the cond graph into the parent graph.
  tensorflow::ops::CondGraphBuilderFn cond_fn =
      [params, parent](const tensorflow::Scope& scope,
                       const std::vector<tensorflow::Output>& inputs,
                       tensorflow::Output* output) {
        DCHECK_EQ(scope.graph(), &parent->graph);
        std::vector<tensorflow::Output> cond_output;
        // The cond graph produces exactly one output: the loop predicate.
        TF_RETURN_IF_ERROR(CopyGraph(
            &params->cond_graph->graph, &parent->graph, &parent->refiner,
            params->cond_inputs, inputs, scope.impl()->name(),
            scope.impl()->control_deps(), &params->cond_output,
            /* nreturn_nodes */ 1, &cond_output));
        *output = cond_output[0];
        return ::tensorflow::OkStatus();
      };

  // 'body_fn' copies the body graph into the parent graph.
  tensorflow::ops::BodyGraphBuilderFn body_fn =
      [params, parent, num_loop_vars](
          const tensorflow::Scope& scope,
          const std::vector<tensorflow::Output>& inputs,
          std::vector<tensorflow::Output>* outputs) {
        DCHECK_EQ(scope.graph(), &parent->graph);
        // The body graph returns one output per loop variable.
        TF_RETURN_IF_ERROR(
            CopyGraph(&params->body_graph->graph, &parent->graph,
                      &parent->refiner, params->body_inputs, inputs,
                      scope.impl()->name(), scope.impl()->control_deps(),
                      params->body_outputs, num_loop_vars, outputs));
        return ::tensorflow::OkStatus();
      };

  // Create the while loop using an internal scope.
  tensorflow::Scope scope =
      NewInternalScope(&parent->graph, &status->status, &parent->refiner)
          .NewSubScope(params->name);

  // Remember where the parent graph's node ids ended so we can find every
  // node BuildWhileLoop() adds below.
  const int first_new_node_id = parent->graph.num_node_ids();

  tensorflow::OutputList loop_outputs;
  status->status = tensorflow::ops::BuildWhileLoop(
      scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
      body_fn, params->name, &loop_outputs);

  // Update name_map with newly-created ops.
  // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
  // a bad status. Once we fix this, we may want to return early instead of
  // executing the following code.
  for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
    Node* new_node = parent->graph.FindNodeId(i);
    if (new_node == nullptr) continue;
    parent->name_map[new_node->name()] = new_node;
  }

  // Populate 'outputs'.
  DCHECK_LE(loop_outputs.size(), num_loop_vars);
  for (int i = 0; i < loop_outputs.size(); ++i) {
    outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
  }
}

}  // namespace
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2083 
// Completes the while-loop construction started by TF_NewWhile: builds the
// loop in the parent graph and then frees the resources held in `params`.
// On mobile/slim builds this is unimplemented and only sets `status`.
void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
                    TF_Output* outputs) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Creating while loops is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
#else
  // If it appears the caller created or modified `params`, don't free resources
  if (!ValidateConstWhileParams(*params, status)) return;
  TF_FinishWhileHelper(params, status, outputs);
  FreeWhileResources(params);
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
2098 
TF_AbortWhile(const TF_WhileParams * params)2099 void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
2100 
// Adds gradient nodes for y w.r.t. x to graph `g`. Thin wrapper that
// delegates to TF_AddGradientsWithPrefix with no name prefix (the gradient
// ops are placed under the default "gradients" scope).
void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
                     TF_Output* dx, TF_Status* status, TF_Output* dy) {
  TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy);
}
2105 
// Adds symbolic-gradient nodes for `y` w.r.t. `x` to graph `g`, placing the
// new ops under scope `prefix` (or "gradients" if `prefix` is null). `dx`, if
// non-null, supplies initial gradients (must have `ny` entries). The computed
// gradients for each x are written to `dy`. Errors are reported via `status`.
void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
                               int ny, TF_Output* x, int nx, TF_Output* dx,
                               TF_Status* status, TF_Output* dy) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Adding gradients is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
#else
  std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
  std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
  std::vector<tensorflow::Output> dy_arg;

  {
    // We need to hold on to the lock while we have a scope that uses TF_Graph.
    mutex_lock graph_lock(g->mu);

    // Node ids below this value existed before we added gradients; used to
    // find the newly-created nodes afterwards.
    const int first_new_node_id = g->graph.num_node_ids();

    string prefix_cmp;
    const char* child_scope_name;
    if (prefix == nullptr) {
      child_scope_name = "gradients";
    } else {
      prefix_cmp = string(prefix) + "/";
      // The operation should fail if the provided name prefix has already been
      // used in this graph
      for (const auto& pair : g->name_map) {
        const string& name = pair.first;
        if ((name == prefix) || absl::StartsWith(name, prefix_cmp)) {
          status->status = InvalidArgument(
              "prefix [", prefix,
              "] conflicts with existing node in the graph named [", name, "]");
          return;
        }
      }
      child_scope_name = prefix;
    }
    tensorflow::Scope scope =
        NewInternalScope(&g->graph, &status->status, &g->refiner)
            .NewSubScope(child_scope_name);

    if (dx != nullptr) {
      // Note: dx is paired with y, so it has `ny` entries.
      std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
      status->status =
          AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
    } else {
      status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
    }

    // Update g->name_map with the name_map from the scope, which will contain
    // the new gradient ops.
    for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
      Node* n = g->graph.FindNodeId(i);
      if (n == nullptr) continue;

      // Adding the gradients to the graph can alter the prefix to prevent
      // name collisions only if this prefix has not been provided explicitly
      // by the user. If it was provided, assert that it remained intact.
      if (prefix != nullptr && !absl::StartsWith(n->name(), prefix_cmp)) {
        status->status = tensorflow::errors::Internal(
            "BUG: The gradients prefix have been unexpectedly altered when "
            "adding the nodes to the graph. This is a bug. Please file an "
            "issue at https://github.com/tensorflow/tensorflow/issues.");
        return;
      }
      // We have a convoluted scheme here: Using the C++ graph construction API
      // to add potentially many nodes to the graph without running the checks
      // (such as uniqueness of the names of nodes) we run with other functions
      // that add a node to the graph (like TF_FinishOperation).
      if (!g->name_map.insert(std::make_pair(n->name(), n)).second) {
        status->status = tensorflow::errors::Internal(
            "BUG: The API allowed construction of a graph with duplicate node "
            "names (",
            n->name(),
            "). This is a bug. Please file an issue at "
            "https://github.com/tensorflow/tensorflow/issues.");
      }
    }
  }

  // Unpack the results from grad_outputs_arg.
  TFOutputsFromOutputs(dy_arg, dy);
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
2191 
2192 // TF_Session functions ----------------------------------------------
2193 
// Wraps an already-created tensorflow::Session. `last_num_graph_nodes` starts
// at 0 so the first run extends the session with the whole graph;
// `extend_before_run` defaults to true to keep session and graph in sync.
TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
    : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {}
2196 
TF_NewSession(TF_Graph * graph,const TF_SessionOptions * opt,TF_Status * status)2197 TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
2198                           TF_Status* status) {
2199   Session* session;
2200   status->status = NewSession(opt->options, &session);
2201   if (status->status.ok()) {
2202     TF_Session* new_session = new TF_Session(session, graph);
2203     if (graph != nullptr) {
2204       mutex_lock l(graph->mu);
2205       graph->sessions[new_session] = "";
2206     }
2207     return new_session;
2208   } else {
2209     LOG(ERROR) << status->status;
2210     DCHECK_EQ(nullptr, session);
2211     return nullptr;
2212   }
2213 }
2214 
// Loads a SavedModel from `export_dir` (selecting the MetaGraphDef matching
// `tags`), imports its GraphDef into the caller-supplied empty `graph`, and
// returns a new TF_Session owning the restored session. Optionally serializes
// the MetaGraphDef into `meta_graph_def`. Returns nullptr on error (with
// `status` set).
TF_Session* TF_LoadSessionFromSavedModel(
    const TF_SessionOptions* session_options, const TF_Buffer* run_options,
    const char* export_dir, const char* const* tags, int tags_len,
    TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
// TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring
// that the tensorflow/cc/saved_model:loader build target is mobile friendly.
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Loading a SavedModel is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
  return nullptr;
#else
  mutex_lock l(graph->mu);
  // The import below would collide with pre-existing nodes, so require an
  // empty graph.
  if (!graph->name_map.empty()) {
    status->status = InvalidArgument("Graph is non-empty.");
    return nullptr;
  }

  RunOptions run_options_proto;
  if (run_options != nullptr && !run_options_proto.ParseFromArray(
                                    run_options->data, run_options->length)) {
    status->status = InvalidArgument("Unparseable RunOptions proto");
    return nullptr;
  }

  std::unordered_set<string> tag_set;
  for (int i = 0; i < tags_len; i++) {
    tag_set.insert(string(tags[i]));
  }

  tensorflow::SavedModelBundle bundle;
  status->status =
      tensorflow::LoadSavedModel(session_options->options, run_options_proto,
                                 export_dir, tag_set, &bundle);
  if (!status->status.ok()) return nullptr;

  // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
  // extends using GraphDefs. The Graph instance is different, but equivalent
  // to the one used to create the session.
  //
  // TODO(jhseu): When Session is modified to take Graphs instead of
  // GraphDefs, return the Graph generated in LoadSavedModel().
  TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefResults results;
  GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
                            import_opts, &results, status);
  TF_DeleteImportGraphDefOptions(import_opts);
  if (!status->status.ok()) return nullptr;

  if (meta_graph_def != nullptr) {
    status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
    if (!status->status.ok()) return nullptr;
  }

  // The session takes ownership of the bundle's session object.
  TF_Session* session = new TF_Session(bundle.session.release(), graph);

  graph->sessions[session] = "";
  // The graph was imported just above, so the session is already up to date
  // with it; record the current node count to skip a redundant Extend().
  session->last_num_graph_nodes = graph->graph.num_node_ids();
  return session;
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
2277 
// Closes the underlying session, releasing its resources; the TF_Session
// object itself must still be freed with TF_DeleteSession.
void TF_CloseSession(TF_Session* s, TF_Status* status) {
  status->status = s->session->Close();
}
2281 
TF_DeleteSession(TF_Session * s,TF_Status * status)2282 void TF_DeleteSession(TF_Session* s, TF_Status* status) {
2283   status->status = ::tensorflow::OkStatus();
2284   if (s == nullptr) return;
2285   TF_Graph* const graph = s->graph;
2286   if (graph != nullptr) {
2287     graph->mu.lock();
2288     graph->sessions.erase(s);
2289     const bool del = graph->delete_requested && graph->sessions.empty();
2290     graph->mu.unlock();
2291     if (del) delete graph;
2292   }
2293   delete s->session;
2294   delete s;
2295 }
2296 
TF_SessionRun(TF_Session * session,const TF_Buffer * run_options,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Buffer * run_metadata,TF_Status * status)2297 void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
2298                    const TF_Output* inputs, TF_Tensor* const* input_values,
2299                    int ninputs, const TF_Output* outputs,
2300                    TF_Tensor** output_values, int noutputs,
2301                    const TF_Operation* const* target_opers, int ntargets,
2302                    TF_Buffer* run_metadata, TF_Status* status) {
2303   // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2304   // directly, instead of requiring us to serialize to a GraphDef and
2305   // call Session::Extend().
2306   if (session->extend_before_run &&
2307       !ExtendSessionGraphHelper(session, status)) {
2308     return;
2309   }
2310 
2311   TF_Run_Setup(noutputs, output_values, status);
2312 
2313   // Convert from TF_Output and TF_Tensor to a string and Tensor.
2314   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2315   if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2316   for (int i = 0; i < ninputs; ++i) {
2317     input_pairs[i].first = OutputName(inputs[i]);
2318   }
2319 
2320   // Convert from TF_Output to string names.
2321   std::vector<string> output_names(noutputs);
2322   for (int i = 0; i < noutputs; ++i) {
2323     output_names[i] = OutputName(outputs[i]);
2324   }
2325 
2326   // Convert from TF_Operation* to string names.
2327   std::vector<string> target_names(ntargets);
2328   for (int i = 0; i < ntargets; ++i) {
2329     target_names[i] = target_opers[i]->node.name();
2330   }
2331 
2332   // Actually run.
2333   TF_Run_Helper(session->session, nullptr, run_options, input_pairs,
2334                 output_names, output_values, target_names, run_metadata,
2335                 status);
2336 }
2337 
TF_SessionPRunSetup(TF_Session * session,const TF_Output * inputs,int ninputs,const TF_Output * outputs,int noutputs,const TF_Operation * const * target_opers,int ntargets,const char ** handle,TF_Status * status)2338 void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
2339                          int ninputs, const TF_Output* outputs, int noutputs,
2340                          const TF_Operation* const* target_opers, int ntargets,
2341                          const char** handle, TF_Status* status) {
2342   *handle = nullptr;
2343 
2344   if (session->extend_before_run &&
2345       !ExtendSessionGraphHelper(session, status)) {
2346     return;
2347   }
2348 
2349   std::vector<string> input_names(ninputs);
2350   for (int i = 0; i < ninputs; ++i) {
2351     input_names[i] = OutputName(inputs[i]);
2352   }
2353 
2354   std::vector<string> output_names(noutputs);
2355   for (int i = 0; i < noutputs; ++i) {
2356     output_names[i] = OutputName(outputs[i]);
2357   }
2358 
2359   std::vector<string> target_names(ntargets);
2360   for (int i = 0; i < ntargets; ++i) {
2361     target_names[i] = target_opers[i]->node.name();
2362   }
2363 
2364   string new_handle;
2365   status->status = session->session->PRunSetup(input_names, output_names,
2366                                                target_names, &new_handle);
2367   if (status->status.ok()) {
2368     char* buf = new char[new_handle.size() + 1];
2369     memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
2370     *handle = buf;
2371   }
2372 }
2373 
// Frees a handle returned by TF_SessionPRunSetup (allocated with new[]).
void TF_DeletePRunHandle(const char* handle) {
  delete[] handle;
  // TODO(suharshs): Free up any resources held by the partial run state.
}
2378 
TF_SessionPRun(TF_Session * session,const char * handle,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Status * status)2379 void TF_SessionPRun(TF_Session* session, const char* handle,
2380                     const TF_Output* inputs, TF_Tensor* const* input_values,
2381                     int ninputs, const TF_Output* outputs,
2382                     TF_Tensor** output_values, int noutputs,
2383                     const TF_Operation* const* target_opers, int ntargets,
2384                     TF_Status* status) {
2385   // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2386   // directly, instead of requiring us to serialize to a GraphDef and
2387   // call Session::Extend().
2388   if (session->extend_before_run &&
2389       !ExtendSessionGraphHelper(session, status)) {
2390     return;
2391   }
2392 
2393   TF_Run_Setup(noutputs, output_values, status);
2394 
2395   // Convert from TF_Output and TF_Tensor to a string and Tensor.
2396   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2397   if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2398   for (int i = 0; i < ninputs; ++i) {
2399     input_pairs[i].first = OutputName(inputs[i]);
2400   }
2401 
2402   // Convert from TF_Output to string names.
2403   std::vector<string> output_names(noutputs);
2404   for (int i = 0; i < noutputs; ++i) {
2405     output_names[i] = OutputName(outputs[i]);
2406   }
2407 
2408   // Convert from TF_Operation* to string names.
2409   std::vector<string> target_names(ntargets);
2410   for (int i = 0; i < ntargets; ++i) {
2411     target_names[i] = target_opers[i]->node.name();
2412   }
2413 
2414   TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
2415                 output_values, target_names, nullptr, status);
2416 }
2417 
TF_TryEvaluateConstant(TF_Graph * graph,TF_Output output,TF_Tensor ** result,TF_Status * status)2418 unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
2419                                      TF_Tensor** result, TF_Status* status) {
2420   *result = nullptr;
2421   mutex_lock l(graph->mu);
2422   OutputTensor tensor(&output.oper->node, output.index);
2423   bool evaluated;
2424   Tensor result_tensor;
2425   status->status = EvaluateConstantTensor(
2426       tensor, graph->refiner, *graph->graph.op_registry(),
2427       graph->graph.versions().producer(), &evaluated, &result_tensor);
2428   if (evaluated) {
2429     DCHECK(status->status.ok());
2430     *result = TF_TensorFromTensor(result_tensor, &status->status);
2431     if (!status->status.ok()) evaluated = false;
2432   }
2433   return evaluated;
2434 }
2435 
TF_NewApiDefMap(TF_Buffer * op_list_buffer,TF_Status * status)2436 TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
2437   tensorflow::OpList op_list;
2438   if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
2439     status->status = InvalidArgument("Unparseable OpList");
2440     return nullptr;
2441   }
2442   status->status = ::tensorflow::OkStatus();
2443   return new TF_ApiDefMap(op_list);
2444 }
2445 
TF_DeleteApiDefMap(TF_ApiDefMap * apimap)2446 void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }
2447 
// Merges a textual ApiDef (length `text_len`, need not be NUL-terminated)
// into the map. Must not be called after TF_ApiDefMapGet, since Get freezes
// the docs via UpdateDocs(). Unimplemented on mobile/slim builds.
void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
                     size_t text_len, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "ApiDefMap is not supported on mobile.");
#else
  mutex_lock l(api_def_map->lock);
  if (api_def_map->update_docs_called) {
    status->status = FailedPrecondition(
        "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
        "called.");
    return;
  }
  string api_def_text(text, text_len);
  status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
2465 
// Looks up the ApiDef for op `name` (length `name_len`) and returns it as a
// serialized proto in a new TF_Buffer (caller owns it), or nullptr if the op
// is unknown or serialization fails. The first call runs UpdateDocs(), after
// which further TF_ApiDefMapPut calls are rejected.
TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
                           size_t name_len, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "ApiDefMap is not supported on mobile.");
  return nullptr;
#else
  mutex_lock l(api_def_map->lock);
  if (!api_def_map->update_docs_called) {
    api_def_map->api_def_map.UpdateDocs();
    api_def_map->update_docs_called = true;
  }
  string name_str(name, name_len);
  const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
  if (api_def == nullptr) {
    // NOTE: status is not set here; a missing op simply yields nullptr.
    return nullptr;
  }

  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(*api_def, ret);
  if (!status->status.ok()) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
  return ret;
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
2493 
TF_GetAllRegisteredKernels(TF_Status * status)2494 TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
2495   tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
2496   TF_Buffer* ret = TF_NewBuffer();
2497   status->status = MessageToBuffer(kernel_list, ret);
2498   if (!status->status.ok()) {
2499     TF_DeleteBuffer(ret);
2500     return nullptr;
2501   }
2502   return ret;
2503 }
2504 
TF_GetRegisteredKernelsForOp(const char * name,TF_Status * status)2505 TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
2506   tensorflow::KernelList kernel_list =
2507       tensorflow::GetRegisteredKernelsForOp(name);
2508   TF_Buffer* ret = TF_NewBuffer();
2509   status->status = MessageToBuffer(kernel_list, ret);
2510   if (!status->status.ok()) {
2511     TF_DeleteBuffer(ret);
2512     return nullptr;
2513   }
2514   return ret;
2515 }
2516 
// Rewires input `dst` to come from `new_src`, after checking index bounds and
// merging the new source's inferred shape into the destination's input shape.
// On success the destination op is recorded as mutated so active sessions
// know to re-extend.
void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
                   TF_Status* status) {
  using tensorflow::RecordMutation;
  mutex_lock l(graph->mu);
  // NOTE(review): GetContext is assumed to return non-null for both nodes
  // (i.e. the refiner has already processed them) — confirm this invariant
  // holds for all callers; no null check is performed here.
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(&new_src.oper->node);

  if (ic->num_outputs() <= new_src.index) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Output index [", new_src.index,
        "] is greater than the number of total outputs [", ic->num_outputs(),
        "].");
    return;
  }
  tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);

  tensorflow::shape_inference::InferenceContext* ic_dst =
      graph->refiner.GetContext(&dst.oper->node);
  if (ic_dst->num_inputs() <= dst.index) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Input index [", dst.index,
        "] is greater than the number of total inputs [", ic_dst->num_inputs(),
        "].");
    return;
  }
  // Reject the rewire if the source's output shape is incompatible with what
  // shape inference already established for the destination's input.
  if (!ic_dst->MergeInput(dst.index, shape)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
        " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
    return;
  }
  status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
                                           &dst.oper->node, dst.index);

  if (TF_GetCode(status) == TF_OK) {
    // This modification only updates the destination node for
    // the purposes of running this graph in a session. Thus, we don't
    // record the source node as being modified.
    RecordMutation(graph, *dst.oper, "updating input tensor");
  }
}
2558 
2559 // TF_Server functions ----------------------------------------------
2560 
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
// Takes ownership of `server`, caching its target string first so that
// TF_ServerTarget can return a stable const char* for the server's lifetime.
TF_Server::TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)
    : target(server->target()), server(std::move(server)) {}
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2565 
// Creates an in-process TensorFlow server from a serialized ServerDef proto
// (`proto`, `proto_len` bytes). Returns nullptr with `status` set on parse or
// creation failure; unimplemented on mobile/slim builds.
TF_Server* TF_NewServer(const void* proto, size_t proto_len,
                        TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
  return nullptr;
#else
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, static_cast<int>(proto_len))) {
    status->status = InvalidArgument(
        "Could not parse provided bytes into a ServerDef protocol buffer");
    return nullptr;
  }

  std::unique_ptr<tensorflow::ServerInterface> out_server;
  status->status = tensorflow::NewServer(server_def, &out_server);
  if (!status->status.ok()) return nullptr;

  return new TF_Server(std::move(out_server));
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
2587 
// Starts the in-process server; unimplemented on mobile/slim builds.
void TF_ServerStart(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  status->status = server->server->Start();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
2596 
// Stops the in-process server; unimplemented on mobile/slim builds.
void TF_ServerStop(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  status->status = server->server->Stop();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
2605 
// Blocks until the server shuts down; unimplemented on mobile/slim builds.
void TF_ServerJoin(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  status->status = server->server->Join();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
2614 
// Returns the server's target string (e.g. for connecting sessions). The
// pointer stays valid for the lifetime of `server` (it points at the string
// cached in the TF_Server constructor). Returns nullptr on mobile/slim
// builds.
const char* TF_ServerTarget(TF_Server* server) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  return nullptr;
#else
  return server->target.c_str();
#endif
}
2622 
// Destroys a TF_Server created by TF_NewServer (no-op on mobile/slim builds,
// where no server can have been created).
void TF_DeleteServer(TF_Server* server) {
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
  delete server;
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
}
2628 
// Registers a callback to receive TensorFlow log messages as C strings.
// No-op on mobile/slim builds.
void TF_RegisterLogListener(void (*listener)(const char*)) {
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
  tensorflow::logging::RegisterListener(listener);
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
}
2634 
// Loads and registers a modular filesystem plugin from the shared library at
// `plugin_filename`; unimplemented on mobile/slim builds.
void TF_RegisterFilesystemPlugin(const char* plugin_filename,
                                 TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "FileSystem plugin functionality is not supported on mobile");
#else
  status->status = tensorflow::RegisterFilesystemPlugin(plugin_filename);
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
2644 
2645 }  // end extern "C"
2646