/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/c_api.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <vector>

#include "absl/strings/match.h"
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"  // NOLINT

#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/framework/logging.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/coding.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/version.h"

// The implementation below is at the top level instead of the
// brain namespace because we are defining 'extern "C"' functions.
using tensorflow::AllocationDescription;
using tensorflow::DataType;
using tensorflow::ExtendSessionGraphHelper;
using tensorflow::Graph;
using tensorflow::GraphDef;
using tensorflow::mutex_lock;
using tensorflow::NameRangeMap;
using tensorflow::NameRangesForNode;
using tensorflow::NewSession;
using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::NodeDef;
using tensorflow::OpDef;
using tensorflow::OpRegistry;
using tensorflow::OutputTensor;
using tensorflow::PartialTensorShape;
using tensorflow::RunMetadata;
using tensorflow::RunOptions;
using tensorflow::Session;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::Tensor;
using tensorflow::TensorBuffer;
using tensorflow::TensorId;
using tensorflow::TensorShape;
using tensorflow::TensorShapeProto;
using tensorflow::VersionDef;
using tensorflow::errors::FailedPrecondition;
using tensorflow::errors::InvalidArgument;
using tensorflow::gtl::ArraySlice;
using tensorflow::strings::StrCat;

extern "C" {

// --------------------------------------------------------------------------
const char* TF_Version() { return TF_VERSION_STRING; }

// --------------------------------------------------------------------------

// --------------------------------------------------------------------------
TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }

void TF_SetTarget(TF_SessionOptions* options, const char* target) {
  options->options.target = target;
}

void TF_SetConfig(TF_SessionOptions* options, const void* proto,
                  size_t proto_len, TF_Status* status) {
  if (!options->options.config.ParseFromArray(proto, proto_len)) {
    status->status = InvalidArgument("Unparseable ConfigProto");
  }
}
// --------------------------------------------------------------------------
TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; }

TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
  void* copy = tensorflow::port::Malloc(proto_len);
  memcpy(copy, proto, proto_len);

  TF_Buffer* buf = new TF_Buffer;
  buf->data = copy;
  buf->length = proto_len;
  buf->data_deallocator = [](void* data, size_t length) {
    tensorflow::port::Free(data);
  };
  return buf;
}
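
// Illustrative usage sketch (not part of the original file):
// TF_NewBufferFromString copies the caller's bytes, and TF_DeleteBuffer later
// runs the installed data_deallocator to free that copy. `serialized` and
// `len` below are hypothetical caller-provided values.
//
//   TF_Buffer* buf = TF_NewBufferFromString(serialized, len);
//   // ... read buf->data / buf->length ...
//   TF_DeleteBuffer(buf);  // frees the copied bytes via data_deallocator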

void TF_DeleteBuffer(TF_Buffer* buffer) {
  if (buffer == nullptr) return;
  if (buffer->data_deallocator != nullptr) {
    (*buffer->data_deallocator)(const_cast<void*>(buffer->data),
                                buffer->length);
  }
  delete buffer;
}

TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }

void TF_TensorFromProto(const TF_Buffer* from, TF_Tensor* to,
                        TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  tensorflow::TensorProto from_tensor_proto;
  status->status = BufferToMessage(from, &from_tensor_proto);
  if (!status->status.ok()) {
    return;
  }
  status->status =
      tensorflow::down_cast<tensorflow::TensorInterface*>(to->tensor)
          ->FromProto(from_tensor_proto);
}
// --------------------------------------------------------------------------

TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
                                              TF_Status* status) {
  Session* session;
  status->status = NewSession(opt->options, &session);
  if (status->status.ok()) {
    return new TF_DeprecatedSession({session});
  } else {
    DCHECK_EQ(nullptr, session);
    return nullptr;
  }
}

void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
  status->status = s->session->Close();
}

void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
  status->status = Status::OK();
  if (s == nullptr) return;
  delete s->session;
  delete s;
}

void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
                    size_t proto_len, TF_Status* status) {
  GraphDef g;
  if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
    status->status = InvalidArgument("Invalid GraphDef");
    return;
  }
  status->status = s->session->Extend(g);
}

}  // end extern "C"

// Reset helper for converting character arrays to string vectors.
static void TF_Reset_Helper(const TF_SessionOptions* opt,
                            const char** containers, int ncontainers,
                            TF_Status* status) {
  std::vector<string> container_names(ncontainers);
  for (int i = 0; i < ncontainers; ++i) {
    container_names[i] = containers[i];
  }

  status->status = Reset(opt->options, container_names);
}

extern "C" {

void TF_Reset(const TF_SessionOptions* opt, const char** containers,
              int ncontainers, TF_Status* status) {
  TF_Reset_Helper(opt, containers, ncontainers, status);
}

}  // end extern "C"

namespace tensorflow {

Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
                       TF_Buffer* out) {
  if (out->data != nullptr) {
    return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
  }
  const size_t proto_size = in.ByteSizeLong();
  void* buf = port::Malloc(proto_size);
  if (buf == nullptr) {
    return tensorflow::errors::ResourceExhausted(
        "Failed to allocate memory to serialize message of type '",
        in.GetTypeName(), "' and size ", proto_size);
  }
  if (!in.SerializeWithCachedSizesToArray(static_cast<uint8*>(buf))) {
    port::Free(buf);
    return InvalidArgument("Unable to serialize ", in.GetTypeName(),
                           " protocol buffer, perhaps the serialized size (",
                           proto_size, " bytes) is too large?");
  }
  out->data = buf;
  out->length = proto_size;
  out->data_deallocator = [](void* data, size_t length) { port::Free(data); };
  return Status::OK();
}

Status BufferToMessage(const TF_Buffer* in,
                       tensorflow::protobuf::MessageLite* out) {
  if (in == nullptr || !out->ParseFromArray(in->data, in->length)) {
    return errors::InvalidArgument("Unparseable ", out->GetTypeName(),
                                   " proto");
  }
  return Status::OK();
}
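
// A minimal round-trip sketch (illustrative only, not part of the original
// file): serialize a ConfigProto into a fresh TF_Buffer with MessageToBuffer,
// then parse it back with BufferToMessage. Error handling is elided.
//
//   tensorflow::ConfigProto config;
//   TF_Buffer* buf = TF_NewBuffer();            // starts with data == nullptr
//   TF_CHECK_OK(MessageToBuffer(config, buf));  // rejects a non-empty buffer
//   tensorflow::ConfigProto parsed;
//   TF_CHECK_OK(BufferToMessage(buf, &parsed));
//   TF_DeleteBuffer(buf);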

void RecordMutation(TF_Graph* graph, const TF_Operation& op,
                    const char* mutation_type) {
  // If any session has already run this node_id, mark this session as
  // unrunnable.
  for (auto it : graph->sessions) {
    mutex_lock session_lock(it.first->mu);
    if (it.first->last_num_graph_nodes > op.node.id()) {
      it.second = strings::StrCat(
          "Operation '", op.node.DebugString(), "' was changed by ",
          mutation_type,
          " after it was run by a session. This mutation will have no effect, "
          "and will trigger an error in the future. Either don't modify "
          "nodes after running them or create a new session.");
    }
  }
}

namespace {

// Helper method that creates a shape handle for a shape described by dims.
tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
    tensorflow::shape_inference::InferenceContext* ic, int num_dims,
    const int64_t* dims) {
  if (num_dims != -1) {
    std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
    dim_vec.reserve(num_dims);
    for (int i = 0; i < num_dims; ++i) {
      dim_vec.push_back(ic->MakeDim(dims[i]));
    }
    return ic->MakeShape(dim_vec);
  } else {
    return ic->UnknownShape();
  }
}

}  // namespace

void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
                                           int num_shapes_and_types,
                                           const int64_t** shapes,
                                           const int* ranks,
                                           const TF_DataType* types,
                                           TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return;
  }

  auto shape_and_type_vec =
      std::vector<tensorflow::shape_inference::ShapeAndType>(
          num_shapes_and_types);
  for (int i = 0; i < num_shapes_and_types; ++i) {
    tensorflow::shape_inference::ShapeHandle shape_handle =
        ShapeHandleFromDims(ic, ranks[i], shapes[i]);
    shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
        shape_handle, static_cast<DataType>(types[i]));
  }

  ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
}

// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadDynamicLibrary(const char* library_filename, void** result,
                          const void** buf, size_t* len);

// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
  if (session->graph != nullptr) {
    // Take the graph lock before the session lock to avoid deadlock. This is
    // safe since session->graph does not change.
    session->graph->mu.lock();
    mutex_lock session_lock(session->mu);
    const Graph& graph = session->graph->graph;

    const string& mutation_warning = session->graph->sessions[session];
    if (!mutation_warning.empty()) {
      // TODO(b/74949947): turn this back into an error status
      LOG(WARNING) << mutation_warning;
      session->graph->sessions[session].clear();
    }

    const auto num_nodes = graph.num_node_ids();
    if (session->last_num_graph_nodes < num_nodes) {
      // TODO(nolivia): check this on a subset of the graph instead of all of
      // it.
      status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
      if (!status->status.ok()) {
        session->graph->mu.unlock();
        return false;
      }

      GraphDef graph_def;
      *graph_def.mutable_versions() = graph.versions();
      // Fill graph_def with nodes with ids in the range
      // [session->last_num_graph_nodes, num_nodes), that is the nodes
      // added since the last TF_SessionRun() call.
      for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
        Node* const node = graph.FindNodeId(id);
        if (node != nullptr && node->IsOp()) {
          NodeDef* const node_def = graph_def.add_node();
          *node_def = node->def();
        }
      }
      *graph_def.mutable_library() = graph.flib_def().ToProto();
      session->graph->mu.unlock();
      status->status = session->session->Extend(std::move(graph_def));
      if (!status->status.ok()) {
        // Contract is we always delete input_values[i].
        return false;
      }
      // Note: session->session is not modified if Extend() fails, so
      // we only set last_num_graph_nodes if it succeeds.
      session->last_num_graph_nodes = num_nodes;
    } else {
      session->graph->mu.unlock();
    }
  }
  return true;
}

}  // namespace tensorflow

static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
                         TF_Status* status) {
  status->status = Status::OK();
  for (int i = 0; i < noutputs; ++i) {
    c_outputs[i] = nullptr;
  }
}

static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
                          std::vector<std::pair<string, Tensor>>* input_pairs,
                          TF_Status* status) {
  const int ninputs = input_pairs->size();
  for (int i = 0; i < ninputs; ++i) {
    status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
    if (!status->status.ok()) return false;
  }
  return true;
}

// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
// result in a zero-sized tensor.
static TF_Tensor* EmptyTensor(TF_DataType dtype,
                              const tensorflow::TensorShape& shape) {
  static char empty;
  int64_t nelems = 1;
  std::vector<tensorflow::int64> dims;
  for (int i = 0; i < shape.dims(); ++i) {
    dims.push_back(shape.dim_size(i));
    nelems *= shape.dim_size(i);
  }
  CHECK_EQ(nelems, 0);
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  return TF_NewTensor(
      dtype, reinterpret_cast<const int64_t*>(dims.data()), shape.dims(),
      reinterpret_cast<void*>(&empty), 0, [](void*, size_t, void*) {}, nullptr);
}

static void TF_Run_Helper(
    Session* session, const char* handle, const TF_Buffer* run_options,
    // Input tensors
    const std::vector<std::pair<string, Tensor>>& input_pairs,
    // Output tensors
    const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
    // Target nodes
    const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
    TF_Status* status) {
  const int noutputs = output_tensor_names.size();
  std::vector<Tensor> outputs(noutputs);
  Status result;

  if (handle == nullptr) {
    RunOptions run_options_proto;
    if (run_options != nullptr && !run_options_proto.ParseFromArray(
                                      run_options->data, run_options->length)) {
      status->status = InvalidArgument("Unparseable RunOptions proto");
      return;
    }
    if (run_metadata != nullptr && run_metadata->data != nullptr) {
      status->status =
          InvalidArgument("Passing non-empty run_metadata is invalid.");
      return;
    }

    RunMetadata run_metadata_proto;
    result = session->Run(run_options_proto, input_pairs, output_tensor_names,
                          target_oper_names, &outputs, &run_metadata_proto);

    // Serialize back to upstream client, who now owns the new buffer
    if (run_metadata != nullptr) {
      status->status = MessageToBuffer(run_metadata_proto, run_metadata);
      if (!status->status.ok()) return;
    }
  } else {
    // NOTE(zongheng): PRun does not support RunOptions yet.
    result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
  }
  if (!result.ok()) {
    status->status = result;
    return;
  }

  // Store results in c_outputs[]
  for (int i = 0; i < noutputs; ++i) {
    const Tensor& src = outputs[i];
    if (!src.IsInitialized() || src.NumElements() == 0) {
      c_outputs[i] =
          EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
      continue;
    }
    c_outputs[i] = TF_TensorFromTensor(src, &status->status);
    if (!status->status.ok()) return;
  }
}

extern "C" {

void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
            // Input tensors
            const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
            // Output tensors
            const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
            // Target nodes
            const char** c_target_oper_names, int ntargets,
            TF_Buffer* run_metadata, TF_Status* status) {
  TF_Run_Setup(noutputs, c_outputs, status);
  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
  if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
  for (int i = 0; i < ninputs; ++i) {
    input_pairs[i].first = c_input_names[i];
  }
  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = c_output_names[i];
  }
  std::vector<string> target_oper_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_oper_names[i] = c_target_oper_names[i];
  }
  TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
                c_outputs, target_oper_names, run_metadata, status);
}
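
// Illustrative call sketch (not part of the original file): feeding one input
// and fetching one output through the deprecated-session API. `session`,
// `in_tensor`, `status`, and the tensor names are hypothetical.
//
//   const char* feed_names[] = {"x:0"};
//   TF_Tensor* feeds[] = {in_tensor};
//   const char* fetch_names[] = {"y:0"};
//   TF_Tensor* fetches[1] = {nullptr};
//   TF_Run(session, /*run_options=*/nullptr, feed_names, feeds, 1,
//          fetch_names, fetches, 1, /*c_target_oper_names=*/nullptr, 0,
//          /*run_metadata=*/nullptr, status);
//   if (TF_GetCode(status) == TF_OK) TF_DeleteTensor(fetches[0]);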

void TF_PRunSetup(TF_DeprecatedSession* s,
                  // Input names
                  const char** c_input_names, int ninputs,
                  // Output names
                  const char** c_output_names, int noutputs,
                  // Target nodes
                  const char** c_target_oper_names, int ntargets,
                  const char** handle, TF_Status* status) {
  *handle = nullptr;

  std::vector<string> input_names(ninputs);
  std::vector<string> output_names(noutputs);
  std::vector<string> target_oper_names(ntargets);
  for (int i = 0; i < ninputs; ++i) {
    input_names[i] = c_input_names[i];
  }
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = c_output_names[i];
  }
  for (int i = 0; i < ntargets; ++i) {
    target_oper_names[i] = c_target_oper_names[i];
  }
  string new_handle;
  status->status = s->session->PRunSetup(input_names, output_names,
                                         target_oper_names, &new_handle);
  if (status->status.ok()) {
    char* buf = new char[new_handle.size() + 1];
    memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
    *handle = buf;
  }
}

void TF_PRun(TF_DeprecatedSession* s, const char* handle,
             // Input tensors
             const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
             // Output tensors
             const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
             // Target nodes
             const char** c_target_oper_names, int ntargets,
             TF_Status* status) {
  TF_Run_Setup(noutputs, c_outputs, status);
  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
  if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
  for (int i = 0; i < ninputs; ++i) {
    input_pairs[i].first = c_input_names[i];
  }

  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = c_output_names[i];
  }
  std::vector<string> target_oper_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_oper_names[i] = c_target_oper_names[i];
  }
  TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
                c_outputs, target_oper_names, nullptr, status);
}

TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
  TF_Library* lib_handle = new TF_Library;
  status->status = tensorflow::LoadDynamicLibrary(
      library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
      &lib_handle->op_list.length);
  if (!status->status.ok()) {
    delete lib_handle;
    return nullptr;
  }
  return lib_handle;
}

TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }

void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
  if (lib_handle == nullptr) return;
  tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
  delete lib_handle;
}

TF_Buffer* TF_GetAllOpList() {
  std::vector<tensorflow::OpDef> op_defs;
  tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
  tensorflow::OpList op_list;
  for (const auto& op : op_defs) {
    *(op_list.add_op()) = op;
  }
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(op_list, ret));
  return ret;
}

// --------------------------------------------------------------------------
// ListDevices & SessionListDevices API

void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; }

TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
  TF_DeviceList* response = new TF_DeviceList;
  if (session && session->session)
    status->status = session->session->ListDevices(&response->response);
  return response;
}

TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
                                               TF_Status* status) {
  TF_DeviceList* response = new TF_DeviceList;
  if (session && session->session)
    status->status = session->session->ListDevices(&response->response);
  return response;
}

int TF_DeviceListCount(const TF_DeviceList* list) {
  return list->response.size();
}

#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
  return_type method_name(const TF_DeviceList* list, const int index,     \
                          TF_Status* status) {                            \
    if (list == nullptr) {                                                \
      status->status = InvalidArgument("list is null!");                  \
      return err_val;                                                     \
    }                                                                     \
    if (index < 0 || index >= list->response.size()) {                    \
      status->status = InvalidArgument("index out of bounds");            \
      return err_val;                                                     \
    }                                                                     \
    status->status = Status::OK();                                        \
    return list->response[index].accessor;                                \
  }

TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
                     nullptr);
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);

#undef TF_DEVICELIST_METHOD
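
// Illustrative sketch (not part of the original file): enumerating the devices
// visible to a TF_Session with the accessors generated above. `session` and
// `status` are assumed to exist already.
//
//   TF_DeviceList* devices = TF_SessionListDevices(session, status);
//   for (int i = 0; i < TF_DeviceListCount(devices); ++i) {
//     const char* name = TF_DeviceListName(devices, i, status);
//     const char* type = TF_DeviceListType(devices, i, status);
//     // e.g. "/job:localhost/replica:0/task:0/device:CPU:0", "CPU"
//   }
//   TF_DeleteDeviceList(devices);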

}  // end extern "C"

// --------------------------------------------------------------------------
// New Graph and Session API

// Helper functions -----------------------------------------------------------

namespace {

TF_Operation* ToOperation(Node* node) {
  return static_cast<TF_Operation*>(static_cast<void*>(node));
}

string OutputName(const TF_Output& output) {
  return StrCat(output.oper->node.name(), ":", output.index);
}

const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
                                          const char* attr_name,
                                          TF_Status* status) {
  const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
  if (attr == nullptr) {
    status->status = InvalidArgument("Operation '", oper->node.name(),
                                     "' has no attr named '", attr_name, "'.");
  }
  return attr;
}

TensorId ToTensorId(const TF_Output& output) {
  return TensorId(output.oper->node.name(), output.index);
}

#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
                                                     int n) {
  std::vector<tensorflow::Output> outputs(n);
  for (int i = 0; i < n; ++i) {
    outputs[i] =
        tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
  }
  return outputs;
}

void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
                          TF_Output* tf_outputs) {
  for (int i = 0; i < outputs.size(); i++) {
    tf_outputs[i].oper = ToOperation(outputs[i].node());
    tf_outputs[i].index = outputs[i].index();
  }
}
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

}  // namespace

// Shape functions -----------------------------------------------------------

void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
                            const int64_t* dims, const int num_dims,
                            TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return;
  }
  tensorflow::shape_inference::ShapeHandle new_shape =
      tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
  status->status = graph->refiner.SetShape(node, output.index, new_shape);
}

int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
                             TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return -1;
  }

  tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);

  // Unknown rank means the number of dimensions is -1.
  if (!ic->RankKnown(shape)) {
    return -1;
  }

  return ic->Rank(shape);
}

void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
                            int num_dims, TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return;
  }

  tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);

  int rank = -1;
  if (ic->RankKnown(shape)) {
    rank = ic->Rank(shape);
  }

  if (num_dims != rank) {
    status->status = InvalidArgument("Expected rank is ", num_dims,
                                     " but actual rank is ", rank);
    return;
  }

  if (num_dims == 0) {
    // Output shape is a scalar.
    return;
  }

  // Rank is greater than 0, so fill in the values, if known, and
  // -1 for unknown values.
  for (int i = 0; i < num_dims; ++i) {
    auto dim = ic->Dim(shape, i);
    int64_t value = -1;
    if (ic->ValueKnown(dim)) {
      value = ic->Value(dim);
    }
    dims[i] = value;
  }
}
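
// Illustrative sketch (not part of the original file): querying the inferred
// shape of an operation's output. `graph`, `op`, and `status` are
// hypothetical; note the two-step pattern of asking for the rank first and
// then the dimensions.
//
//   TF_Output out = {op, 0};
//   int rank = TF_GraphGetTensorNumDims(graph, out, status);
//   if (rank > 0) {
//     std::vector<int64_t> dims(rank);
//     TF_GraphGetTensorShape(graph, out, dims.data(), rank, status);
//     // dims[i] == -1 for dimensions whose size is unknown
//   }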

// TF_OperationDescription functions ------------------------------------------

extern "C" {

TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
                                               const char* op_type,
                                               const char* oper_name)
    TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
  return new TF_OperationDescription(graph, op_type, oper_name);
}

TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
                                         const char* oper_name) {
  mutex_lock l(graph->mu);
  return TF_NewOperationLocked(graph, op_type, oper_name);
}

void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
  desc->node_builder.Device(device);
}

void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
  desc->node_builder.Input(&input.oper->node, input.index);
}

void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
                     int num_inputs) {
  std::vector<NodeBuilder::NodeOut> input_list;
  input_list.reserve(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
  }
  desc->node_builder.Input(input_list);
}

void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
  desc->node_builder.ControlInput(&input->node);
}

void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
  desc->colocation_constraints.emplace(
      StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
}

void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
                      const void* value, size_t length) {
  tensorflow::StringPiece s(static_cast<const char*>(value), length);
  desc->node_builder.Attr(attr_name, s);
}

void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
                          const void* const* values, const size_t* lengths,
                          int num_values) {
  if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
    desc->colocation_constraints.clear();
    for (int i = 0; i < num_values; ++i) {
      desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
                                           lengths[i]);
    }
  } else {
    std::vector<tensorflow::StringPiece> v;
    v.reserve(num_values);
    for (int i = 0; i < num_values; ++i) {
      v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
    }
    desc->node_builder.Attr(attr_name, v);
  }
}

void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
                   int64_t value) {
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
}

void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
                       const int64_t* values, int num_values) {
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  desc->node_builder.Attr(
      attr_name,
      ArraySlice<const tensorflow::int64>(
          reinterpret_cast<const tensorflow::int64*>(values), num_values));
}

void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
                     float value) {
  desc->node_builder.Attr(attr_name, value);
}

void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
                         const float* values, int num_values) {
  desc->node_builder.Attr(attr_name,
                          ArraySlice<const float>(values, num_values));
}

void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
                    unsigned char value) {
  desc->node_builder.Attr(attr_name, static_cast<bool>(value));
}

void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
                        const unsigned char* values, int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  desc->node_builder.Attr(attr_name,
                          ArraySlice<const bool>(b.get(), num_values));
}

void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
                    TF_DataType value) {
  desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
}

void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
                        const TF_DataType* values, int num_values) {
  desc->node_builder.Attr(
      attr_name, ArraySlice<const DataType>(
                     reinterpret_cast<const DataType*>(values), num_values));
}

void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name,
                           const char* placeholder) {
  tensorflow::AttrValue attr_value;
  attr_value.set_placeholder(placeholder);
  desc->node_builder.Attr(attr_name, attr_value);
}

void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
                        const char* value, size_t length) {
  tensorflow::NameAttrList func_name;
  func_name.set_name(string(value, value + length));
  desc->node_builder.Attr(attr_name, func_name);
}

void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
                     const int64_t* dims, int num_dims) {
  PartialTensorShape shape;
  if (num_dims >= 0) {
    static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                  "64-bit int types should match in size");
    shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
        reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
  }
  desc->node_builder.Attr(attr_name, shape);
}

void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
                         const int64_t* const* dims, const int* num_dims,
                         int num_shapes) {
  std::vector<PartialTensorShape> shapes;
  shapes.reserve(num_shapes);
  for (int i = 0; i < num_shapes; ++i) {
    if (num_dims[i] < 0) {
      shapes.emplace_back();
    } else {
      static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                    "64-bit int types should match in size");
      shapes.emplace_back(ArraySlice<tensorflow::int64>(
          reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
    }
  }
  desc->node_builder.Attr(attr_name, shapes);
}

void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
                                const char* attr_name, const void* proto,
                                size_t proto_len, TF_Status* status) {
  // shape.ParseFromArray takes an int as length, this function takes size_t,
  // make sure there is no information loss.
  if (proto_len > std::numeric_limits<int>::max()) {
    status->status = InvalidArgument(
        "proto_len (", proto_len,
        " bytes) is too large to be parsed by the protocol buffer library");
    return;
  }
  TensorShapeProto shape;
  if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
    desc->node_builder.Attr(attr_name, shape);
    status->status = Status::OK();
  } else {
    status->status = InvalidArgument("Unparseable TensorShapeProto");
  }
}

void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
                                    const char* attr_name,
                                    const void* const* protos,
                                    const size_t* proto_lens, int num_shapes,
                                    TF_Status* status) {
  std::vector<TensorShapeProto> shapes;
  shapes.resize(num_shapes);
  for (int i = 0; i < num_shapes; ++i) {
    if (proto_lens[i] > std::numeric_limits<int>::max()) {
      status->status = InvalidArgument(
          "length of element ", i, " in the list (", proto_lens[i],
          " bytes) is too large to be parsed by the protocol buffer library");
      return;
    }
    if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) {
      status->status =
          InvalidArgument("Unparseable TensorShapeProto at index ", i);
      return;
    }
  }
  desc->node_builder.Attr(attr_name, shapes);
  status->status = Status::OK();
}

void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
                      TF_Tensor* value, TF_Status* status) {
  Tensor t;
  status->status = TF_TensorToTensor(value, &t);
  if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
}

void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
                          TF_Tensor* const* values, int num_values,
                          TF_Status* status) {
  status->status = Status::OK();
  std::vector<Tensor> t;
  t.reserve(num_values);

  for (int i = 0; i < num_values && status->status.ok(); ++i) {
    Tensor v;
    status->status = TF_TensorToTensor(values[i], &v);
    t.emplace_back(v);
  }

  if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
}

void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
                          const void* proto, size_t proto_len,
                          TF_Status* status) {
  tensorflow::AttrValue attr_value;
  if (!attr_value.ParseFromArray(proto, proto_len)) {
    status->status = InvalidArgument("Unparseable AttrValue proto");
    return;
  }

  if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
    if (attr_value.value_case() != tensorflow::AttrValue::kList &&
        attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
      status->status =
          InvalidArgument("Expected \"list\" field for \"",
                          tensorflow::kColocationAttrName, "\" attribute");
      return;
    }
    desc->colocation_constraints.clear();
    for (const string& location : attr_value.list().s()) {
      desc->colocation_constraints.insert(location);
    }
  } else {
    desc->node_builder.Attr(attr_name, std::move(attr_value));
  }

  status->status = Status::OK();
}

TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
                                       TF_Status* status)
    TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
  Node* ret = nullptr;

  if (desc->graph->name_map.count(desc->node_builder.node_name())) {
    status->status = InvalidArgument("Duplicate node name in graph: '",
                                     desc->node_builder.node_name(), "'");
  } else {
    if (!desc->colocation_constraints.empty()) {
      desc->node_builder.Attr(
          tensorflow::kColocationAttrName,
          std::vector<string>(desc->colocation_constraints.begin(),
                              desc->colocation_constraints.end()));
    }
    status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret,
                                                 /*consume=*/true);

    if (status->status.ok()) {
      // Run shape inference function for newly added node.
      status->status = desc->graph->refiner.AddNode(ret);
    }
    if (status->status.ok()) {
      // Add the node to the name-to-node mapping.
      desc->graph->name_map[ret->name()] = ret;
    } else if (ret != nullptr) {
      desc->graph->graph.RemoveNode(ret);
      ret = nullptr;
    }
  }

  delete desc;

  return ToOperation(ret);
}

TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
                                 TF_Status* status) {
  mutex_lock l(desc->graph->mu);
  return TF_FinishOperationLocked(desc, status);
}
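
// Illustrative sketch (not part of the original file): the usual
// describe-then-finish pattern for adding a node to a graph. The graph, node
// name, and attribute value are hypothetical; TF_FinishOperation always
// consumes the description, whether or not it succeeds.
//
//   TF_OperationDescription* d = TF_NewOperation(graph, "Placeholder", "x");
//   TF_SetAttrType(d, "dtype", TF_FLOAT);
//   TF_Operation* x = TF_FinishOperation(d, status);
//   if (TF_GetCode(status) != TF_OK) { /* x is nullptr; handle the error */ }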
1097 
1098 // TF_Operation functions
1099 // ----------------------------------------------------------
1100 
TF_OperationName(TF_Operation * oper)1101 const char* TF_OperationName(TF_Operation* oper) {
1102   return oper->node.name().c_str();
1103 }
1104 
TF_OperationOpType(TF_Operation * oper)1105 const char* TF_OperationOpType(TF_Operation* oper) {
1106   return oper->node.type_string().c_str();
1107 }
1108 
TF_OperationDevice(TF_Operation * oper)1109 const char* TF_OperationDevice(TF_Operation* oper) {
1110   return oper->node.requested_device().c_str();
1111 }
1112 
TF_OperationNumOutputs(TF_Operation * oper)1113 int TF_OperationNumOutputs(TF_Operation* oper) {
1114   return oper->node.num_outputs();
1115 }
1116 
TF_OperationOutputType(TF_Output oper_out)1117 TF_DataType TF_OperationOutputType(TF_Output oper_out) {
1118   return static_cast<TF_DataType>(
1119       oper_out.oper->node.output_type(oper_out.index));
1120 }
1121 
TF_OperationOutputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1122 int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
1123                                  TF_Status* status) {
1124   NameRangeMap name_ranges;
1125   status->status =
1126       NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
1127   if (!status->status.ok()) return -1;
1128   auto iter = name_ranges.find(arg_name);
1129   if (iter == name_ranges.end()) {
1130     status->status = InvalidArgument("Output arg '", arg_name, "' not found");
1131     return -1;
1132   }
1133   return iter->second.second - iter->second.first;
1134 }
1135 
TF_OperationNumInputs(TF_Operation * oper)1136 int TF_OperationNumInputs(TF_Operation* oper) {
1137   return oper->node.num_inputs();
1138 }
1139 
TF_OperationInputType(TF_Input oper_in)1140 TF_DataType TF_OperationInputType(TF_Input oper_in) {
1141   return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
1142 }
1143 
TF_OperationInputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1144 int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
1145                                 TF_Status* status) {
1146   NameRangeMap name_ranges;
1147   status->status =
1148       NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
1149   if (!status->status.ok()) return -1;
1150   auto iter = name_ranges.find(arg_name);
1151   if (iter == name_ranges.end()) {
1152     status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1153     return -1;
1154   }
1155   return iter->second.second - iter->second.first;
1156 }
1157 
TF_OperationInput(TF_Input oper_in)1158 TF_Output TF_OperationInput(TF_Input oper_in) {
1159   const tensorflow::Edge* edge;
1160   Status s = oper_in.oper->node.input_edge(oper_in.index, &edge);
1161   if (!s.ok()) {
1162     return {nullptr, -1};
1163   }
1164 
1165   return {ToOperation(edge->src()), edge->src_output()};
1166 }
1167 
TF_OperationAllInputs(TF_Operation * oper,TF_Output * inputs,int max_inputs)1168 void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs,
1169                            int max_inputs) {
1170   for (auto* edge : oper->node.in_edges()) {
1171     if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) {
1172       inputs[edge->dst_input()] = {ToOperation(edge->src()),
1173                                    edge->src_output()};
1174     }
1175   }
1176 }
1177 
TF_OperationOutputNumConsumers(TF_Output oper_out)1178 int TF_OperationOutputNumConsumers(TF_Output oper_out) {
1179   int count = 0;
1180   for (const auto* edge : oper_out.oper->node.out_edges()) {
1181     if (edge->src_output() == oper_out.index) {
1182       ++count;
1183     }
1184   }
1185   return count;
1186 }
1187 
TF_OperationOutputConsumers(TF_Output oper_out,TF_Input * consumers,int max_consumers)1188 int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
1189                                 int max_consumers) {
1190   int count = 0;
1191   for (const auto* edge : oper_out.oper->node.out_edges()) {
1192     if (edge->src_output() == oper_out.index) {
1193       if (count < max_consumers) {
1194         consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
1195       }
1196       ++count;
1197     }
1198   }
1199   return count;
1200 }
1201 
TF_OperationNumControlInputs(TF_Operation * oper)1202 int TF_OperationNumControlInputs(TF_Operation* oper) {
1203   int count = 0;
1204   for (const auto* edge : oper->node.in_edges()) {
1205     if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1206       ++count;
1207     }
1208   }
1209   return count;
1210 }
1211 
TF_OperationGetControlInputs(TF_Operation * oper,TF_Operation ** control_inputs,int max_control_inputs)1212 int TF_OperationGetControlInputs(TF_Operation* oper,
1213                                  TF_Operation** control_inputs,
1214                                  int max_control_inputs) {
1215   int count = 0;
1216   for (const auto* edge : oper->node.in_edges()) {
1217     if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1218       if (count < max_control_inputs) {
1219         control_inputs[count] = ToOperation(edge->src());
1220       }
1221       ++count;
1222     }
1223   }
1224   return count;
1225 }
1226 
TF_OperationNumControlOutputs(TF_Operation * oper)1227 int TF_OperationNumControlOutputs(TF_Operation* oper) {
1228   int count = 0;
1229   for (const auto* edge : oper->node.out_edges()) {
1230     if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1231       ++count;
1232     }
1233   }
1234   return count;
1235 }
1236 
1237 int TF_OperationGetControlOutputs(TF_Operation* oper,
1238                                   TF_Operation** control_outputs,
1239                                   int max_control_outputs) {
1240   int count = 0;
1241   for (const auto* edge : oper->node.out_edges()) {
1242     if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1243       if (count < max_control_outputs) {
1244         control_outputs[count] = ToOperation(edge->dst());
1245       }
1246       ++count;
1247     }
1248   }
1249   return count;
1250 }
1251 
1252 TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1253                                             const char* attr_name,
1254                                             TF_Status* status) {
1255   TF_AttrMetadata metadata;
1256   const auto* attr = GetAttrValue(oper, attr_name, status);
1257   if (!status->status.ok()) return metadata;
1258   switch (attr->value_case()) {
1259 #define SINGLE_CASE(kK, attr_type, size_expr) \
1260   case tensorflow::AttrValue::kK:             \
1261     metadata.is_list = 0;                     \
1262     metadata.list_size = -1;                  \
1263     metadata.type = attr_type;                \
1264     metadata.total_size = size_expr;          \
1265     break;
1266 
1267     SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
1268     SINGLE_CASE(kI, TF_ATTR_INT, -1);
1269     SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
1270     SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
1271     SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
1272     SINGLE_CASE(kShape, TF_ATTR_SHAPE,
1273                 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
1274     SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
1275 #undef SINGLE_CASE
1276 
1277     case tensorflow::AttrValue::kList:
1278       metadata.is_list = 1;
1279       metadata.list_size = 0;
1280       metadata.total_size = -1;
1281 #define LIST_CASE(field, attr_type, ...)              \
1282   if (attr->list().field##_size() > 0) {              \
1283     metadata.type = attr_type;                        \
1284     metadata.list_size = attr->list().field##_size(); \
1285     __VA_ARGS__;                                      \
1286     break;                                            \
1287   }
1288 
1289       LIST_CASE(
1290           s, TF_ATTR_STRING, metadata.total_size = 0;
1291           for (int i = 0; i < attr->list().s_size();
1292                ++i) { metadata.total_size += attr->list().s(i).size(); });
1293       LIST_CASE(i, TF_ATTR_INT);
1294       LIST_CASE(f, TF_ATTR_FLOAT);
1295       LIST_CASE(b, TF_ATTR_BOOL);
1296       LIST_CASE(type, TF_ATTR_TYPE);
1297       LIST_CASE(
1298           shape, TF_ATTR_SHAPE, metadata.total_size = 0;
1299           for (int i = 0; i < attr->list().shape_size(); ++i) {
1300             const auto& s = attr->list().shape(i);
1301             metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
1302           });
1303       LIST_CASE(tensor, TF_ATTR_TENSOR);
1304       LIST_CASE(tensor, TF_ATTR_FUNC);
1305 #undef LIST_CASE
1306       // All lists empty, determine the type from the OpDef.
1307       if (metadata.list_size == 0) {
1308         for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
1309           const auto& a = oper->node.op_def().attr(i);
1310           if (a.name() != attr_name) continue;
1311           const string& typestr = a.type();
1312           if (typestr == "list(string)") {
1313             metadata.type = TF_ATTR_STRING;
1314           } else if (typestr == "list(int)") {
1315             metadata.type = TF_ATTR_INT;
1316           } else if (typestr == "list(float)") {
1317             metadata.type = TF_ATTR_FLOAT;
1318           } else if (typestr == "list(bool)") {
1319             metadata.type = TF_ATTR_BOOL;
1320           } else if (typestr == "list(type)") {
1321             metadata.type = TF_ATTR_TYPE;
1322           } else if (typestr == "list(shape)") {
1323             metadata.type = TF_ATTR_SHAPE;
1324           } else if (typestr == "list(tensor)") {
1325             metadata.type = TF_ATTR_TENSOR;
1326           } else if (typestr == "list(func)") {
1327             metadata.type = TF_ATTR_FUNC;
1328           } else {
1329             status->status = InvalidArgument(
1330                 "Attribute '", attr_name,
1331                 "' has an empty value of an unrecognized type '", typestr, "'");
1332             return metadata;
1333           }
1334         }
1335       }
1336       break;
1337 
1338     case tensorflow::AttrValue::kPlaceholder:
1339       metadata.is_list = 0;
1340       metadata.list_size = -1;
1341       metadata.type = TF_ATTR_PLACEHOLDER;
1342       metadata.total_size = -1;
1343       break;
1344 
1345     case tensorflow::AttrValue::kFunc:
1346       metadata.is_list = 0;
1347       metadata.list_size = -1;
1348       metadata.type = TF_ATTR_FUNC;
1349       metadata.total_size = -1;
1350       break;
1351 
1352     case tensorflow::AttrValue::VALUE_NOT_SET:
1353       status->status =
1354           InvalidArgument("Attribute '", attr_name, "' has no value set");
1355       break;
1356   }
1357   return metadata;
1358 }
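
// Example (illustrative sketch): using TF_OperationGetAttrMetadata() to size a
// buffer before fetching a string attribute. `op` and `status` are assumed to
// be valid, and "value" is a hypothetical attribute name.
//
//   TF_AttrMetadata m = TF_OperationGetAttrMetadata(op, "value", status);
//   if (TF_GetCode(status) == TF_OK && !m.is_list && m.type == TF_ATTR_STRING) {
//     std::string s(m.total_size, '\0');
//     TF_OperationGetAttrString(op, "value", &s[0], s.size(), status);
//   }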
1359 
1360 void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
1361                                void* value, size_t max_length,
1362                                TF_Status* status) {
1363   const auto* attr = GetAttrValue(oper, attr_name, status);
1364   if (!status->status.ok()) return;
1365   if (attr->value_case() != tensorflow::AttrValue::kS) {
1366     status->status =
1367         InvalidArgument("Attribute '", attr_name, "' is not a string");
1368     return;
1369   }
1370   if (max_length <= 0) {
1371     return;
1372   }
1373   const auto& s = attr->s();
1374   std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length));
1375 }
1376 
1377 void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
1378                                    void** values, size_t* lengths,
1379                                    int max_values, void* storage,
1380                                    size_t storage_size, TF_Status* status) {
1381   const auto* attr = GetAttrValue(oper, attr_name, status);
1382   if (!status->status.ok()) return;
1383   if (attr->value_case() != tensorflow::AttrValue::kList) {
1384     status->status =
1385         InvalidArgument("Value for '", attr_name, "' is not a list");
1386     return;
1387   }
1388   const auto len = std::min(max_values, attr->list().s_size());
1389   char* p = static_cast<char*>(storage);
1390   for (int i = 0; i < len; ++i) {
1391     const string& s = attr->list().s(i);
1392     values[i] = p;
1393     lengths[i] = s.size();
1394     if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
1395       status->status = InvalidArgument(
1396           "Not enough storage to hold the requested list of strings");
1397       return;
1398     }
1399     memcpy(values[i], s.data(), s.size());
1400     p += s.size();
1401   }
1402 }
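
// Example (illustrative sketch): fetching a string-list attribute, using the
// metadata to size both the pointer/length arrays and the flat character
// storage. `op` and `status` are assumed valid; "names" is a hypothetical
// attribute name.
//
//   TF_AttrMetadata m = TF_OperationGetAttrMetadata(op, "names", status);
//   int n = static_cast<int>(m.list_size);
//   std::vector<void*> values(n);
//   std::vector<size_t> lengths(n);
//   std::vector<char> storage(m.total_size);
//   TF_OperationGetAttrStringList(op, "names", values.data(), lengths.data(),
//                                 n, storage.data(), storage.size(), status);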
1403 
1404 #define DEFINE_GETATTR(func, c_type, cpp_type, list_field)                   \
1405   void func(TF_Operation* oper, const char* attr_name, c_type* value,        \
1406             TF_Status* status) {                                             \
1407     cpp_type v;                                                              \
1408     status->status =                                                         \
1409         tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v);          \
1410     if (!status->status.ok()) return;                                        \
1411     *value = static_cast<c_type>(v);                                         \
1412   }                                                                          \
1413   void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
1414                   int max_values, TF_Status* status) {                       \
1415     const auto* attr = GetAttrValue(oper, attr_name, status);                \
1416     if (!status->status.ok()) return;                                        \
1417     if (attr->value_case() != tensorflow::AttrValue::kList) {                \
1418       status->status =                                                       \
1419           InvalidArgument("Value for '", attr_name, "' is not a list.");     \
1420       return;                                                                \
1421     }                                                                        \
1422     const auto len = std::min(max_values, attr->list().list_field##_size()); \
1423     for (int i = 0; i < len; ++i) {                                          \
1424       values[i] = static_cast<c_type>(attr->list().list_field(i));           \
1425     }                                                                        \
1426   }
1427 DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, int64_t, i);
1428 DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
1429 DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
1430 DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
1431 #undef DEFINE_GETATTR
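
// Example (illustrative sketch): the scalar getters generated above all follow
// the same shape. `op` and `status` are assumed valid; "axis" and "dtype" are
// hypothetical attribute names.
//
//   int64_t axis = 0;
//   TF_OperationGetAttrInt(op, "axis", &axis, status);
//   TF_DataType dtype;
//   TF_OperationGetAttrType(op, "dtype", &dtype, status);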
1432 
1433 void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
1434                               int64_t* value, int num_dims, TF_Status* status) {
1435   PartialTensorShape shape;
1436   status->status =
1437       tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
1438   if (!status->status.ok()) return;
1439   auto len = std::min(shape.dims(), num_dims);
1440   for (int i = 0; i < len; ++i) {
1441     value[i] = shape.dim_size(i);
1442   }
1443 }
1444 
1445 void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
1446                                   int64_t** dims, int* num_dims, int num_shapes,
1447                                   int64_t* storage, int storage_size,
1448                                   TF_Status* status) {
1449   std::vector<PartialTensorShape> shapes;
1450   status->status =
1451       tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
1452   if (!status->status.ok()) return;
1453   auto len = std::min(static_cast<int>(shapes.size()), num_shapes);
1454   int64_t* p = storage;
1455   int storage_left = storage_size;
1456   for (int i = 0; i < len; ++i) {
1457     // shapes[i].dims() == -1 for shapes with an unknown rank.
1458     int64_t n = shapes[i].dims();
1459     num_dims[i] = n;
1460     dims[i] = p;
1461     if (n < 0) {
1462       continue;
1463     }
1464     if (storage_left < n) {
1465       status->status = InvalidArgument(
1466           "Not enough storage to hold the requested list of shapes");
1467       return;
1468     }
1469     storage_left -= n;
1470     for (int j = 0; j < n; ++j, ++p) {
1471       *p = shapes[i].dim_size(j);
1472     }
1473   }
1474 }
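
// Example (illustrative sketch): the caller provides one flat `storage` buffer
// that must hold the dimensions of every shape; TF_AttrMetadata.total_size
// reports exactly that count. `op` and `status` are assumed valid; "shapes" is
// a hypothetical attribute name.
//
//   TF_AttrMetadata m = TF_OperationGetAttrMetadata(op, "shapes", status);
//   int num_shapes = static_cast<int>(m.list_size);
//   std::vector<int64_t*> dims(num_shapes);
//   std::vector<int> num_dims(num_shapes);
//   std::vector<int64_t> storage(m.total_size);
//   TF_OperationGetAttrShapeList(op, "shapes", dims.data(), num_dims.data(),
//                                num_shapes, storage.data(),
//                                static_cast<int>(storage.size()), status);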
1475 
1476 void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
1477                                          const char* attr_name,
1478                                          TF_Buffer* value, TF_Status* status) {
1479   const auto* attr = GetAttrValue(oper, attr_name, status);
1480   if (!status->status.ok()) return;
1481   if (attr->value_case() != tensorflow::AttrValue::kShape) {
1482     status->status =
1483         InvalidArgument("Value for '", attr_name, "' is not a shape.");
1484     return;
1485   }
1486   status->status = MessageToBuffer(attr->shape(), value);
1487 }
1488 
1489 void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
1490                                              const char* attr_name,
1491                                              TF_Buffer** values, int max_values,
1492                                              TF_Status* status) {
1493   const auto* attr = GetAttrValue(oper, attr_name, status);
1494   if (!status->status.ok()) return;
1495   if (attr->value_case() != tensorflow::AttrValue::kList) {
1496     status->status =
1497         InvalidArgument("Value for '", attr_name, "' is not a list");
1498     return;
1499   }
1500   const auto len = std::min(max_values, attr->list().shape_size());
1501   for (int i = 0; i < len; ++i) {
1502     values[i] = TF_NewBuffer();
1503     status->status = MessageToBuffer(attr->list().shape(i), values[i]);
1504     if (!status->status.ok()) {
1505       // Delete everything allocated so far; the operation has failed.
1506       for (int j = 0; j <= i; ++j) {
1507         TF_DeleteBuffer(values[j]);
1508       }
1509       return;
1510     }
1511   }
1512 }
1513 
1514 void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
1515                                TF_Tensor** value, TF_Status* status) {
1516   *value = nullptr;
1517   Tensor t;
1518   status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
1519   if (!status->status.ok()) return;
1520   *value = TF_TensorFromTensor(t, &status->status);
1521 }
1522 
1523 void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
1524                                    TF_Tensor** values, int max_values,
1525                                    TF_Status* status) {
1526   std::vector<Tensor> ts;
1527   status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
1528   if (!status->status.ok()) return;
1529   const auto len = std::min(max_values, static_cast<int>(ts.size()));
1530   for (int i = 0; i < len; ++i) {
1531     values[i] = TF_TensorFromTensor(ts[i], &status->status);
1532   }
1533 }
1534 
1535 void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
1536                                    TF_Buffer* output_attr_value,
1537                                    TF_Status* status) {
1538   const auto* attr = GetAttrValue(oper, attr_name, status);
1539   if (!status->status.ok()) return;
1540   status->status = MessageToBuffer(*attr, output_attr_value);
1541 }
1542 
1543 void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
1544                            TF_Status* status) {
1545   status->status = MessageToBuffer(oper->node.def(), output_node_def);
1546 }
1547 
1548 // TF_Graph functions ---------------------------------------------------------
1549 
1550 TF_Graph::TF_Graph()
1551     : graph(tensorflow::OpRegistry::Global()),
1552       refiner(graph.versions().producer(), graph.op_registry()),
1553       delete_requested(false),
1554       parent(nullptr),
1555       parent_inputs(nullptr) {
1556   // Tell the shape refiner to also run shape inference on functions.
1557   refiner.set_function_library_for_shape_inference(&graph.flib_def());
1558 }
1559 
1560 TF_Graph* TF_NewGraph() { return new TF_Graph; }
1561 
1562 void TF_DeleteGraph(TF_Graph* g) {
1563   if (g == nullptr) return;
1564   g->mu.lock();
1565   g->delete_requested = true;
1566   const bool del = g->sessions.empty();
1567   g->mu.unlock();
1568   if (del) delete g;
1569 }
1570 
1571 TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
1572   mutex_lock l(graph->mu);
1573   auto iter = graph->name_map.find(oper_name);
1574   if (iter == graph->name_map.end()) {
1575     return nullptr;
1576   } else {
1577     return ToOperation(iter->second);
1578   }
1579 }
1580 
1581 TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
1582   if (*pos == 0) {
1583     // Advance past the first sentinel nodes in every graph (the source & sink).
1584     *pos += 2;
1585   } else {
1586     // Advance to the next node.
1587     *pos += 1;
1588   }
1589 
1590   mutex_lock l(graph->mu);
1591   while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) {
1592     Node* node = graph->graph.FindNodeId(*pos);
1593     // FindNodeId() returns nullptr for nodes that have been deleted.
1594     // We aren't currently allowing nodes to be deleted, but it is safer
1595     // to still check.
1596     if (node != nullptr) return ToOperation(node);
1597     *pos += 1;
1598   }
1599 
1600   // No more nodes.
1601   return nullptr;
1602 }
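
// Example (illustrative sketch): iterating every operation in a graph. `graph`
// is assumed to be a valid TF_Graph*; TF_OperationName() is declared in the
// public C API header.
//
//   size_t pos = 0;
//   TF_Operation* op;
//   while ((op = TF_GraphNextOperation(graph, &pos)) != nullptr) {
//     printf("%s\n", TF_OperationName(op));
//   }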
1603 
1604 void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
1605                         TF_Status* status) {
1606   GraphDef def;
1607   {
1608     mutex_lock l(graph->mu);
1609     graph->graph.ToGraphDef(&def);
1610   }
1611   status->status = MessageToBuffer(def, output_graph_def);
1612 }
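
// Example (illustrative sketch): serializing a graph and releasing the buffer.
// `graph` and `status` are assumed valid.
//
//   TF_Buffer* buf = TF_NewBuffer();
//   TF_GraphToGraphDef(graph, buf, status);
//   if (TF_GetCode(status) == TF_OK) {
//     // buf->data / buf->length now hold the serialized GraphDef bytes.
//   }
//   TF_DeleteBuffer(buf);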
1613 
1614 void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
1615                       TF_Buffer* output_op_def, TF_Status* status) {
1616   const OpDef* op_def;
1617   {
1618     mutex_lock l(graph->mu);
1619     status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
1620     if (!status->status.ok()) return;
1621   }
1622   status->status = MessageToBuffer(*op_def, output_op_def);
1623 }
1624 
1625 void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
1626                       TF_Status* status) {
1627   VersionDef versions;
1628   {
1629     mutex_lock l(graph->mu);
1630     versions = graph->graph.versions();
1631   }
1632   status->status = MessageToBuffer(versions, output_version_def);
1633 }
1634 
1635 TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
1636   return new TF_ImportGraphDefOptions;
1637 }
1638 void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
1639   delete opts;
1640 }
1641 void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
1642                                        const char* prefix) {
1643   opts->opts.prefix = prefix;
1644 }
1645 void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts,
1646                                               const char* device) {
1647   opts->opts.default_device = device;
1648 }
1649 
1650 void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
1651                                               unsigned char uniquify_names) {
1652   opts->opts.uniquify_names = uniquify_names;
1653 }
1654 
1655 void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
1656                                                unsigned char uniquify_prefix) {
1657   opts->opts.uniquify_prefix = uniquify_prefix;
1658 }
1659 
1660 void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
1661                                              const char* src_name,
1662                                              int src_index, TF_Output dst) {
1663   opts->tensor_id_data.push_back(src_name);
1664   const string& src_name_str = opts->tensor_id_data.back();
1665   // We don't need to store dst's name in tensor_id_data, since `dst` must
1666   // outlive the ImportGraphDef call.
1667   opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
1668 }
1669 
1670 void TF_ImportGraphDefOptionsRemapControlDependency(
1671     TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
1672   opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
1673       TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
1674 }
1675 
1676 extern void TF_ImportGraphDefOptionsAddControlDependency(
1677     TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
1678   opts->opts.control_dependencies.push_back(oper->node.name());
1679 }
1680 
1681 void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
1682                                              const char* oper_name, int index) {
1683   opts->tensor_id_data.push_back(oper_name);
1684   const string& oper_name_str = opts->tensor_id_data.back();
1685   opts->opts.return_tensors.emplace_back(oper_name_str, index);
1686 }
1687 
1688 int TF_ImportGraphDefOptionsNumReturnOutputs(
1689     const TF_ImportGraphDefOptions* opts) {
1690   return opts->opts.return_tensors.size();
1691 }
1692 
1693 void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
1694                                                 const char* oper_name) {
1695   opts->opts.return_nodes.push_back(oper_name);
1696 }
1697 
1698 int TF_ImportGraphDefOptionsNumReturnOperations(
1699     const TF_ImportGraphDefOptions* opts) {
1700   return opts->opts.return_nodes.size();
1701 }
1702 
1703 void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
1704                                            int* num_outputs,
1705                                            TF_Output** outputs) {
1706   *num_outputs = results->return_tensors.size();
1707   *outputs = results->return_tensors.data();
1708 }
1709 
1710 void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
1711                                               int* num_opers,
1712                                               TF_Operation*** opers) {
1713   *num_opers = results->return_nodes.size();
1714   *opers = results->return_nodes.data();
1715 }
1716 
1717 void TF_ImportGraphDefResultsMissingUnusedInputMappings(
1718     TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
1719     const char*** src_names, int** src_indexes) {
1720   *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
1721   *src_names = results->missing_unused_key_names.data();
1722   *src_indexes = results->missing_unused_key_indexes.data();
1723 }
1724 
1725 void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
1726   delete results;
1727 }
1728 
1729 static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
1730                                       const TF_ImportGraphDefOptions* opts,
1731                                       TF_ImportGraphDefResults* tf_results,
1732                                       TF_Status* status)
1733     TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
1734   const int last_node_id = graph->graph.num_node_ids();
1735   tensorflow::ImportGraphDefResults results;
1736   status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
1737                                               &graph->refiner, &results);
1738   if (!status->status.ok()) return;
1739 
1740   // Add new nodes to name_map
1741   for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
1742     auto* node = graph->graph.FindNodeId(i);
1743     if (node != nullptr) graph->name_map[node->name()] = node;
1744   }
1745 
1746   // Populate return_tensors
1747   DCHECK(tf_results->return_tensors.empty());
1748   tf_results->return_tensors.resize(results.return_tensors.size());
1749   for (int i = 0; i < results.return_tensors.size(); ++i) {
1750     tf_results->return_tensors[i].oper =
1751         ToOperation(results.return_tensors[i].first);
1752     tf_results->return_tensors[i].index = results.return_tensors[i].second;
1753   }
1754 
1755   // Populate return_nodes
1756   DCHECK(tf_results->return_nodes.empty());
1757   tf_results->return_nodes.resize(results.return_nodes.size());
1758   for (int i = 0; i < results.return_nodes.size(); ++i) {
1759     tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
1760   }
1761 
1762   // Populate missing unused map keys
1763   DCHECK(tf_results->missing_unused_key_names.empty());
1764   DCHECK(tf_results->missing_unused_key_indexes.empty());
1765   DCHECK(tf_results->missing_unused_key_names_data.empty());
1766 
1767   size_t size = results.missing_unused_input_map_keys.size();
1768   tf_results->missing_unused_key_names.resize(size);
1769   tf_results->missing_unused_key_indexes.resize(size);
1770 
1771   for (int i = 0; i < size; ++i) {
1772     TensorId id = results.missing_unused_input_map_keys[i];
1773     tf_results->missing_unused_key_names_data.emplace_back(id.first);
1774     tf_results->missing_unused_key_names[i] =
1775         tf_results->missing_unused_key_names_data.back().c_str();
1776     tf_results->missing_unused_key_indexes[i] = id.second;
1777   }
1778 }
1779 
1780 TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
1781     TF_Graph* graph, const TF_Buffer* graph_def,
1782     const TF_ImportGraphDefOptions* options, TF_Status* status) {
1783   GraphDef def;
1784   if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
1785                                        graph_def->length)) {
1786     status->status = InvalidArgument("Invalid GraphDef");
1787     return nullptr;
1788   }
1789   auto results = new TF_ImportGraphDefResults();
1790   mutex_lock l(graph->mu);
1791   GraphImportGraphDefLocked(graph, def, options, results, status);
1792   if (!status->status.ok()) {
1793     delete results;
1794     return nullptr;
1795   }
1796   return results;
1797 }
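
// Example (illustrative sketch): importing a GraphDef and retrieving a return
// output. `graph`, `graph_def` (a TF_Buffer*), and `status` are assumed valid;
// "some_op" is a hypothetical node name in the imported GraphDef.
//
//   TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
//   TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
//   TF_ImportGraphDefOptionsAddReturnOutput(opts, "some_op", 0);
//   TF_ImportGraphDefResults* res =
//       TF_GraphImportGraphDefWithResults(graph, graph_def, opts, status);
//   if (TF_GetCode(status) == TF_OK) {
//     int num_outputs = 0;
//     TF_Output* outputs = nullptr;
//     TF_ImportGraphDefResultsReturnOutputs(res, &num_outputs, &outputs);
//     TF_DeleteImportGraphDefResults(res);
//   }
//   TF_DeleteImportGraphDefOptions(opts);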
1798 
1799 void TF_GraphImportGraphDefWithReturnOutputs(
1800     TF_Graph* graph, const TF_Buffer* graph_def,
1801     const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
1802     int num_return_outputs, TF_Status* status) {
1803   if (num_return_outputs != options->opts.return_tensors.size()) {
1804     status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
1805                                      options->opts.return_tensors.size(),
1806                                      ", got ", num_return_outputs);
1807     return;
1808   }
1809   if (num_return_outputs > 0 && return_outputs == nullptr) {
1810     status->status = InvalidArgument(
1811         "'return_outputs' must be preallocated to length ", num_return_outputs);
1812     return;
1813   }
1814   GraphDef def;
1815   if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
1816                                        graph_def->length)) {
1817     status->status = InvalidArgument("Invalid GraphDef");
1818     return;
1819   }
1820   TF_ImportGraphDefResults results;
1821   mutex_lock l(graph->mu);
1822   GraphImportGraphDefLocked(graph, def, options, &results, status);
1823   DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
1824   memcpy(return_outputs, results.return_tensors.data(),
1825          num_return_outputs * sizeof(TF_Output));
1826 }
1827 
1828 void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
1829                             const TF_ImportGraphDefOptions* options,
1830                             TF_Status* status) {
1831   TF_ImportGraphDefResults* results =
1832       TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
1833   TF_DeleteImportGraphDefResults(results);
1834 }
1835 
1836 // While loop functions -------------------------------------------------------
1837 
1838 namespace {
1839 
1840 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1841 
1842 // Creates a placeholder representing an input to the cond or body graph.
1843 // TODO(skyewm): remove these from final graph
1844 bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
1845                  TF_Output* input, TF_Status* status) {
1846   TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
1847   TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
1848   // TODO(skyewm): set placeholder shape
1849   TF_Operation* oper = TF_FinishOperation(desc, status);
1850   if (!status->status.ok()) return false;
1851   *input = {oper, 0};
1852   return true;
1853 }
1854 
1855 // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
1856 // `src_inputs[i]` will have that input replaced with `dst_inputs[i]`.  `prefix`
1857 // will be prepended to copied node names. `control_deps` are nodes in
1858 // `dst_graph` that the copied `src_graph` nodes will have control dependencies
1859 // on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
1860 // in `dst_graph` will be returned. `return_nodes` must be non-null.
1861 Status CopyGraph(Graph* src_graph, Graph* dst_graph,
1862                  tensorflow::ShapeRefiner* dst_refiner,
1863                  const TF_Output* src_inputs,
1864                  const std::vector<tensorflow::Output>& dst_inputs,
1865                  const string& prefix,
1866                  const std::vector<tensorflow::Operation>& control_deps,
1867                  const TF_Output* nodes_to_return, int nreturn_nodes,
1868                  std::vector<tensorflow::Output>* return_nodes) {
1869   DCHECK(return_nodes != nullptr);
1870   GraphDef gdef;
1871   src_graph->ToGraphDef(&gdef);
1872 
1873   tensorflow::ImportGraphDefOptions opts;
1874   opts.prefix = prefix;
1875 
1876   for (int i = 0; i < dst_inputs.size(); ++i) {
1877     opts.input_map[ToTensorId(src_inputs[i])] =
1878         TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
1879   }
1880   opts.skip_mapped_nodes = true;
1881 
1882   for (const tensorflow::Operation& op : control_deps) {
1883     opts.control_dependencies.push_back(op.node()->name());
1884   }
1885 
1886   for (int i = 0; i < nreturn_nodes; ++i) {
1887     opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
1888   }
1889 
1890   // TODO(skyewm): change to OutputTensor
1891   tensorflow::ImportGraphDefResults results;
1892   TF_RETURN_IF_ERROR(
1893       ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
1894 
1895   for (const auto& pair : results.return_tensors) {
1896     return_nodes->emplace_back(pair.first, pair.second);
1897   }
1898   return Status::OK();
1899 }
1900 
1901 bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
1902   if (params.cond_graph == nullptr || params.body_graph == nullptr ||
1903       params.cond_graph->parent == nullptr ||
1904       params.cond_graph->parent != params.body_graph->parent ||
1905       params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
1906       params.ninputs <= 0 || params.cond_inputs == nullptr ||
1907       params.body_inputs == nullptr || params.body_outputs == nullptr) {
1908     s->status = InvalidArgument(
1909         "TF_WhileParams must be created by successful TF_NewWhile() call");
1910     return false;
1911   }
1912   return true;
1913 }
1914 
1915 bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
1916   if (params.cond_output.oper == nullptr) {
1917     s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
1918     return false;
1919   }
1920   for (int i = 0; i < params.ninputs; ++i) {
1921     if (params.body_outputs[i].oper == nullptr) {
1922       s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
1923                                   "field isn't set");
1924       return false;
1925     }
1926   }
1927   if (params.name == nullptr) {
1928     s->status = InvalidArgument("TF_WhileParams `name` field is null");
1929     return false;
1930   }
1931   return true;
1932 }
1933 
1934 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1935 
1936 void FreeWhileResources(const TF_WhileParams* params) {
1937   TF_DeleteGraph(params->cond_graph);
1938   TF_DeleteGraph(params->body_graph);
1939   delete[] params->cond_inputs;
1940   delete[] params->body_inputs;
1941   delete[] params->body_outputs;
1942 }
1943 
1944 TF_WhileParams EmptyWhileParams() {
1945   return {0,       nullptr, nullptr, {nullptr, 0},
1946           nullptr, nullptr, nullptr, nullptr};
1947 }
1948 
1949 }  // namespace
1950 
1951 TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
1952                            TF_Status* status) {
1953 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
1954   status->status = tensorflow::errors::Unimplemented(
1955       "Creating while loops is not supported on mobile. File a bug at "
1956       "https://github.com/tensorflow/tensorflow/issues if this feature is "
1957       "important to you");
1958   return EmptyWhileParams();
1959 #else
1960   if (ninputs == 0) {
1961     status->status =
1962         InvalidArgument("TF_NewWhile() must be passed at least one input");
1963     return EmptyWhileParams();
1964   }
1965 
1966   TF_Graph* cond_graph = TF_NewGraph();
1967   TF_Graph* body_graph = TF_NewGraph();
1968   cond_graph->parent = g;
1969   cond_graph->parent_inputs = inputs;
1970   body_graph->parent = g;
1971   body_graph->parent_inputs = inputs;
1972 
1973   TF_Output* cond_inputs = new TF_Output[ninputs];
1974   TF_Output cond_output = {nullptr, -1};
1975   TF_Output* body_inputs = new TF_Output[ninputs];
1976   TF_Output* body_outputs = new TF_Output[ninputs];
1977   for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
1978   const char* name = nullptr;
1979 
1980   for (int i = 0; i < ninputs; ++i) {
1981     // TODO(skyewm): prefix names with underscore (requires some plumbing)
1982     if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
1983                      &cond_inputs[i], status)) {
1984       break;
1985     }
1986     if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
1987                      &body_inputs[i], status)) {
1988       break;
1989     }
1990   }
1991 
1992   TF_WhileParams params = {ninputs,    cond_graph,  cond_inputs,  cond_output,
1993                            body_graph, body_inputs, body_outputs, name};
1994 
1995   if (!status->status.ok()) {
1996     FreeWhileResources(&params);
1997     return EmptyWhileParams();
1998   }
1999   return params;
2000 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2001 }
2002 
2003 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2004 namespace {
2005 
2006 // TODO(skyewm): make nodes in while loop unfetchable like in Python version
2007 void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
2008                           TF_Output* outputs) {
2009   if (!ValidateInputWhileParams(*params, status)) return;
2010 
2011   TF_Graph* parent = params->cond_graph->parent;
2012   TF_Output* parent_inputs = params->cond_graph->parent_inputs;
2013   int num_loop_vars = params->ninputs;
2014 
2015   mutex_lock l(parent->mu);
2016 
2017   // 'cond_fn' copies the cond graph into the parent graph.
2018   tensorflow::ops::CondGraphBuilderFn cond_fn =
2019       [params, parent](const tensorflow::Scope& scope,
2020                        const std::vector<tensorflow::Output>& inputs,
2021                        tensorflow::Output* output) {
2022         DCHECK_EQ(scope.graph(), &parent->graph);
2023         std::vector<tensorflow::Output> cond_output;
2024         TF_RETURN_IF_ERROR(CopyGraph(
2025             &params->cond_graph->graph, &parent->graph, &parent->refiner,
2026             params->cond_inputs, inputs, scope.impl()->name(),
2027             scope.impl()->control_deps(), &params->cond_output,
2028             /* nreturn_nodes */ 1, &cond_output));
2029         *output = cond_output[0];
2030         return Status::OK();
2031       };
2032 
2033   // 'body_fn' copies the body graph into the parent graph.
2034   tensorflow::ops::BodyGraphBuilderFn body_fn =
2035       [params, parent, num_loop_vars](
2036           const tensorflow::Scope& scope,
2037           const std::vector<tensorflow::Output>& inputs,
2038           std::vector<tensorflow::Output>* outputs) {
2039         DCHECK_EQ(scope.graph(), &parent->graph);
2040         TF_RETURN_IF_ERROR(
2041             CopyGraph(&params->body_graph->graph, &parent->graph,
2042                       &parent->refiner, params->body_inputs, inputs,
2043                       scope.impl()->name(), scope.impl()->control_deps(),
2044                       params->body_outputs, num_loop_vars, outputs));
2045         return Status::OK();
2046       };
2047 
2048   // Create the while loop using an internal scope.
2049   tensorflow::Scope scope =
2050       NewInternalScope(&parent->graph, &status->status, &parent->refiner)
2051           .NewSubScope(params->name);
2052 
2053   const int first_new_node_id = parent->graph.num_node_ids();
2054 
2055   tensorflow::OutputList loop_outputs;
2056   status->status = tensorflow::ops::BuildWhileLoop(
2057       scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
2058       body_fn, params->name, &loop_outputs);
2059 
2060   // Update name_map with newly-created ops.
2061   // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
2062   // a bad status. Once we fix this, we may want to return early instead of
2063   // executing the following code.
2064   for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
2065     Node* new_node = parent->graph.FindNodeId(i);
2066     if (new_node == nullptr) continue;
2067     parent->name_map[new_node->name()] = new_node;
2068   }
2069 
2070   // Populate 'outputs'.
2071   DCHECK_LE(loop_outputs.size(), num_loop_vars);
2072   for (int i = 0; i < loop_outputs.size(); ++i) {
2073     outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
2074   }
2075 }
2076 
2077 }  // namespace
2078 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2079 
2080 void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
2081                     TF_Output* outputs) {
2082 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2083   status->status = tensorflow::errors::Unimplemented(
2084       "Creating while loops is not supported on mobile. File a bug at "
2085       "https://github.com/tensorflow/tensorflow/issues if this feature is "
2086       "important to you");
2087 #else
2088   // If it appears the caller created or modified `params`, don't free resources
2089   if (!ValidateConstWhileParams(*params, status)) return;
2090   TF_FinishWhileHelper(params, status, outputs);
2091   FreeWhileResources(params);
2092 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2093 }
2094 
2095 void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
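
// Example (illustrative sketch): the intended TF_NewWhile/TF_FinishWhile flow.
// `graph`, `status`, and `x_op` (producing the single loop variable) are
// assumed valid; building the cond/body subgraphs is elided.
//
//   TF_Output loop_vars[1] = {{x_op, 0}};
//   TF_WhileParams p = TF_NewWhile(graph, loop_vars, 1, status);
//   if (TF_GetCode(status) == TF_OK) {
//     // Build the condition in p.cond_graph from p.cond_inputs and set
//     // p.cond_output; build the body in p.body_graph from p.body_inputs and
//     // set p.body_outputs[0]; then name and finish (or abort) the loop.
//     p.name = "my_while";
//     TF_Output outputs[1];
//     TF_FinishWhile(&p, status, outputs);
//   }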
2096 
2097 void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
2098                      TF_Output* dx, TF_Status* status, TF_Output* dy) {
2099   TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy);
2100 }
2101 
2102 void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
2103                                int ny, TF_Output* x, int nx, TF_Output* dx,
2104                                TF_Status* status, TF_Output* dy) {
2105 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2106   status->status = tensorflow::errors::Unimplemented(
2107       "Adding gradients is not supported on mobile. File a bug at "
2108       "https://github.com/tensorflow/tensorflow/issues if this feature is "
2109       "important to you");
2110 #else
2111   std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
2112   std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
2113   std::vector<tensorflow::Output> dy_arg;
2114 
2115   {
2116     // We need to hold on to the lock while we have a scope that uses TF_Graph.
2117     mutex_lock graph_lock(g->mu);
2118 
2119     const int first_new_node_id = g->graph.num_node_ids();
2120 
2121     string prefix_cmp;
2122     const char* child_scope_name;
2123     if (prefix == nullptr) {
2124       child_scope_name = "gradients";
2125     } else {
2126       prefix_cmp = string(prefix) + "/";
2127       // The operation should fail if the provided name prefix has already been
2128       // used in this graph
2129       for (const auto& pair : g->name_map) {
2130         const string& name = pair.first;
2131         if ((name == prefix) || absl::StartsWith(name, prefix_cmp)) {
2132           status->status = InvalidArgument(
2133               "prefix [", prefix,
2134               "] conflicts with existing node in the graph named [", name, "]");
2135           return;
2136         }
2137       }
2138       child_scope_name = prefix;
2139     }
2140     tensorflow::Scope scope =
2141         NewInternalScope(&g->graph, &status->status, &g->refiner)
2142             .NewSubScope(child_scope_name);
2143 
2144     if (dx != nullptr) {
2145       std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
2146       status->status =
2147           AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
2148     } else {
2149       status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
2150     }
2151 
2152     // Update g->name_map with the name_map from the scope, which will contain
2153     // the new gradient ops.
2154     for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
2155       Node* n = g->graph.FindNodeId(i);
2156       if (n == nullptr) continue;
2157 
2158       // Adding the gradients to the graph can alter the prefix to prevent
2159       // name collisions only if this prefix has not been provided explicitly
2160       // by the user. If it was provided, assert that it remained intact.
2161       if (prefix != nullptr && !absl::StartsWith(n->name(), prefix_cmp)) {
2162         status->status = tensorflow::errors::Internal(
2163             "BUG: The gradients prefix has been unexpectedly altered when "
2164             "adding the nodes to the graph. This is a bug. Please file an "
2165             "issue at https://github.com/tensorflow/tensorflow/issues.");
2166         return;
2167       }
2168       // We have a convoluted scheme here: we use the C++ graph construction
2169       // API to add potentially many nodes to the graph without running the
2170       // checks (such as uniqueness of node names) that other node-adding
2171       // entry points (like TF_FinishOperation) perform.
2172       if (!g->name_map.insert(std::make_pair(n->name(), n)).second) {
2173         status->status = tensorflow::errors::Internal(
2174             "BUG: The API allowed construction of a graph with duplicate node "
2175             "names (",
2176             n->name(),
2177             "). This is a bug. Please file an issue at "
2178             "https://github.com/tensorflow/tensorflow/issues.");
2179       }
2180     }
2181   }
2182 
2183   // Unpack the results from grad_outputs_arg.
2184   TFOutputsFromOutputs(dy_arg, dy);
2185 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2186 }
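
// Example (illustrative sketch): adding symbolic gradients of one output with
// respect to one input. `graph`, `status`, `loss_op`, and `weight_op` are
// assumed to be valid operations already in `graph`.
//
//   TF_Output y[1] = {{loss_op, 0}};
//   TF_Output x[1] = {{weight_op, 0}};
//   TF_Output dy[1];
//   TF_AddGradientsWithPrefix(graph, "grads", y, 1, x, 1, /*dx=*/nullptr,
//                             status, dy);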
2187 
2188 // TF_Session functions ----------------------------------------------
2189 
2190 TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
2191     : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {}
2192 
2193 TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
2194                           TF_Status* status) {
2195   Session* session;
2196   status->status = NewSession(opt->options, &session);
2197   if (status->status.ok()) {
2198     TF_Session* new_session = new TF_Session(session, graph);
2199     if (graph != nullptr) {
2200       mutex_lock l(graph->mu);
2201       graph->sessions[new_session] = "";
2202     }
2203     return new_session;
2204   } else {
2205     LOG(ERROR) << status->status;
2206     DCHECK_EQ(nullptr, session);
2207     return nullptr;
2208   }
2209 }
2210 
2211 TF_Session* TF_LoadSessionFromSavedModel(
2212     const TF_SessionOptions* session_options, const TF_Buffer* run_options,
2213     const char* export_dir, const char* const* tags, int tags_len,
2214     TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
2215 // TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring
2216 // that the tensorflow/cc/saved_model:loader build target is mobile friendly.
2217 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2218   status->status = tensorflow::errors::Unimplemented(
2219       "Loading a SavedModel is not supported on mobile. File a bug at "
2220       "https://github.com/tensorflow/tensorflow/issues if this feature is "
2221       "important to you");
2222   return nullptr;
2223 #else
2224   mutex_lock l(graph->mu);
2225   if (!graph->name_map.empty()) {
2226     status->status = InvalidArgument("Graph is non-empty.");
2227     return nullptr;
2228   }
2229 
2230   RunOptions run_options_proto;
2231   if (run_options != nullptr && !run_options_proto.ParseFromArray(
2232                                     run_options->data, run_options->length)) {
2233     status->status = InvalidArgument("Unparseable RunOptions proto");
2234     return nullptr;
2235   }
2236 
2237   std::unordered_set<string> tag_set;
2238   for (int i = 0; i < tags_len; i++) {
2239     tag_set.insert(string(tags[i]));
2240   }
2241 
2242   tensorflow::SavedModelBundle bundle;
2243   status->status =
2244       tensorflow::LoadSavedModel(session_options->options, run_options_proto,
2245                                  export_dir, tag_set, &bundle);
2246   if (!status->status.ok()) return nullptr;
2247 
2248   // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
2249   // extends using GraphDefs. The Graph instance is different, but equivalent
2250   // to the one used to create the session.
2251   //
2252   // TODO(jhseu): When Session is modified to take Graphs instead of
2253   // GraphDefs, return the Graph generated in LoadSavedModel().
2254   TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
2255   TF_ImportGraphDefResults results;
2256   GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
2257                             import_opts, &results, status);
2258   TF_DeleteImportGraphDefOptions(import_opts);
2259   if (!status->status.ok()) return nullptr;
2260 
2261   if (meta_graph_def != nullptr) {
2262     status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
2263     if (!status->status.ok()) return nullptr;
2264   }
2265 
2266   TF_Session* session = new TF_Session(bundle.session.release(), graph);
2267 
2268   graph->sessions[session] = "";
2269   session->last_num_graph_nodes = graph->graph.num_node_ids();
2270   return session;
2271 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2272 }
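
// Example (illustrative sketch): loading a SavedModel into a fresh graph.
// `status` is assumed valid, the path is a placeholder, and "serve" is the
// conventional serving tag.
//
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_Graph* graph = TF_NewGraph();
//   const char* tags[] = {"serve"};
//   TF_Session* session = TF_LoadSessionFromSavedModel(
//       opts, /*run_options=*/nullptr, "/path/to/saved_model", tags, 1, graph,
//       /*meta_graph_def=*/nullptr, status);
//   TF_DeleteSessionOptions(opts);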
2273 
2274 void TF_CloseSession(TF_Session* s, TF_Status* status) {
2275   status->status = s->session->Close();
2276 }
2277 
2278 void TF_DeleteSession(TF_Session* s, TF_Status* status) {
2279   status->status = Status::OK();
2280   if (s == nullptr) return;
2281   TF_Graph* const graph = s->graph;
2282   if (graph != nullptr) {
2283     graph->mu.lock();
2284     graph->sessions.erase(s);
2285     const bool del = graph->delete_requested && graph->sessions.empty();
2286     graph->mu.unlock();
2287     if (del) delete graph;
2288   }
2289   delete s->session;
2290   delete s;
2291 }
2292 
2293 void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
2294                    const TF_Output* inputs, TF_Tensor* const* input_values,
2295                    int ninputs, const TF_Output* outputs,
2296                    TF_Tensor** output_values, int noutputs,
2297                    const TF_Operation* const* target_opers, int ntargets,
2298                    TF_Buffer* run_metadata, TF_Status* status) {
2299   // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2300   // directly, instead of requiring us to serialize to a GraphDef and
2301   // call Session::Extend().
2302   if (session->extend_before_run &&
2303       !ExtendSessionGraphHelper(session, status)) {
2304     return;
2305   }
2306 
2307   TF_Run_Setup(noutputs, output_values, status);
2308 
2309   // Convert from TF_Output and TF_Tensor to a string and Tensor.
2310   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2311   if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2312   for (int i = 0; i < ninputs; ++i) {
2313     input_pairs[i].first = OutputName(inputs[i]);
2314   }
2315 
2316   // Convert from TF_Output to string names.
2317   std::vector<string> output_names(noutputs);
2318   for (int i = 0; i < noutputs; ++i) {
2319     output_names[i] = OutputName(outputs[i]);
2320   }
2321 
2322   // Convert from TF_Operation* to string names.
2323   std::vector<string> target_names(ntargets);
2324   for (int i = 0; i < ntargets; ++i) {
2325     target_names[i] = target_opers[i]->node.name();
2326   }
2327 
2328   // Actually run.
2329   TF_Run_Helper(session->session, nullptr, run_options, input_pairs,
2330                 output_names, output_values, target_names, run_metadata,
2331                 status);
2332 }
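
// Example (illustrative sketch): feeding one tensor and fetching one output.
// `session`, `status`, `input_op`, `output_op`, and `in_val` (a TF_Tensor*
// built by the caller) are assumed valid.
//
//   TF_Output feed = {input_op, 0};
//   TF_Output fetch = {output_op, 0};
//   TF_Tensor* out_val = nullptr;
//   TF_SessionRun(session, /*run_options=*/nullptr, &feed, &in_val, 1, &fetch,
//                 &out_val, 1, /*target_opers=*/nullptr, 0,
//                 /*run_metadata=*/nullptr, status);
//   if (TF_GetCode(status) == TF_OK) TF_DeleteTensor(out_val);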
2333 
2334 void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
2335                          int ninputs, const TF_Output* outputs, int noutputs,
2336                          const TF_Operation* const* target_opers, int ntargets,
2337                          const char** handle, TF_Status* status) {
2338   *handle = nullptr;
2339 
2340   if (session->extend_before_run &&
2341       !ExtendSessionGraphHelper(session, status)) {
2342     return;
2343   }
2344 
2345   std::vector<string> input_names(ninputs);
2346   for (int i = 0; i < ninputs; ++i) {
2347     input_names[i] = OutputName(inputs[i]);
2348   }
2349 
2350   std::vector<string> output_names(noutputs);
2351   for (int i = 0; i < noutputs; ++i) {
2352     output_names[i] = OutputName(outputs[i]);
2353   }
2354 
2355   std::vector<string> target_names(ntargets);
2356   for (int i = 0; i < ntargets; ++i) {
2357     target_names[i] = target_opers[i]->node.name();
2358   }
2359 
2360   string new_handle;
2361   status->status = session->session->PRunSetup(input_names, output_names,
2362                                                target_names, &new_handle);
2363   if (status->status.ok()) {
2364     char* buf = new char[new_handle.size() + 1];
2365     memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
2366     *handle = buf;
2367   }
2368 }
2369 
2370 void TF_DeletePRunHandle(const char* handle) {
2371   delete[] handle;
2372   // TODO(suharshs): Free up any resources held by the partial run state.
2373 }
2374 
2375 void TF_SessionPRun(TF_Session* session, const char* handle,
2376                     const TF_Output* inputs, TF_Tensor* const* input_values,
2377                     int ninputs, const TF_Output* outputs,
2378                     TF_Tensor** output_values, int noutputs,
2379                     const TF_Operation* const* target_opers, int ntargets,
2380                     TF_Status* status) {
2381   // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2382   // directly, instead of requiring us to serialize to a GraphDef and
2383   // call Session::Extend().
2384   if (session->extend_before_run &&
2385       !ExtendSessionGraphHelper(session, status)) {
2386     return;
2387   }
2388 
2389   TF_Run_Setup(noutputs, output_values, status);
2390 
2391   // Convert from TF_Output and TF_Tensor to a string and Tensor.
2392   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2393   if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2394   for (int i = 0; i < ninputs; ++i) {
2395     input_pairs[i].first = OutputName(inputs[i]);
2396   }
2397 
2398   // Convert from TF_Output to string names.
2399   std::vector<string> output_names(noutputs);
2400   for (int i = 0; i < noutputs; ++i) {
2401     output_names[i] = OutputName(outputs[i]);
2402   }
2403 
2404   // Convert from TF_Operation* to string names.
2405   std::vector<string> target_names(ntargets);
2406   for (int i = 0; i < ntargets; ++i) {
2407     target_names[i] = target_opers[i]->node.name();
2408   }
2409 
2410   TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
2411                 output_values, target_names, nullptr, status);
2412 }
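
// Illustrative sketch (editor's addition, not part of the implementation):
// a typical partial-run sequence built from the three functions above.
// `session`, `feed`, `feed_value`, `fetch`, and `status` are hypothetical
// placeholders assumed to have been created elsewhere with the C API.
//
//   const char* handle = nullptr;
//   TF_SessionPRunSetup(session, &feed, 1, &fetch, 1,
//                       /*target_opers=*/nullptr, /*ntargets=*/0,
//                       &handle, status);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_Tensor* out = nullptr;
//     TF_SessionPRun(session, handle, &feed, &feed_value, 1,
//                    &fetch, &out, 1, nullptr, 0, status);
//     if (out != nullptr) TF_DeleteTensor(out);
//   }
//   TF_DeletePRunHandle(handle);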

unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
                                     TF_Tensor** result, TF_Status* status) {
  *result = nullptr;
  mutex_lock l(graph->mu);
  OutputTensor tensor(&output.oper->node, output.index);
  bool evaluated;
  Tensor result_tensor;
  status->status = EvaluateConstantTensor(
      tensor, graph->refiner, *graph->graph.op_registry(),
      graph->graph.versions().producer(), &evaluated, &result_tensor);
  if (evaluated) {
    DCHECK(status->status.ok());
    *result = TF_TensorFromTensor(result_tensor, &status->status);
    if (!status->status.ok()) evaluated = false;
  }
  return evaluated;
}
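
// Illustrative sketch (editor's addition): how a caller might use the
// constant-folding helper above. `graph`, `some_output`, and `status` are
// hypothetical placeholders.
//
//   TF_Tensor* value = nullptr;
//   if (TF_TryEvaluateConstant(graph, some_output, &value, status)) {
//     // `value` now holds the folded constant; the caller owns it.
//     TF_DeleteTensor(value);
//   }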

TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
  tensorflow::OpList op_list;
  if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
    status->status = InvalidArgument("Unparseable OpList");
    return nullptr;
  }
  status->status = Status::OK();
  return new TF_ApiDefMap(op_list);
}

void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }

void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
                     size_t text_len, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "ApiDefMap is not supported on mobile.");
#else
  mutex_lock l(api_def_map->lock);
  if (api_def_map->update_docs_called) {
    status->status = FailedPrecondition(
        "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
        "called.");
    return;
  }
  string api_def_text(text, text_len);
  status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
                           size_t name_len, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "ApiDefMap is not supported on mobile.");
  return nullptr;
#else
  mutex_lock l(api_def_map->lock);
  if (!api_def_map->update_docs_called) {
    api_def_map->api_def_map.UpdateDocs();
    api_def_map->update_docs_called = true;
  }
  string name_str(name, name_len);
  const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
  if (api_def == nullptr) {
    return nullptr;
  }

  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(*api_def, ret);
  if (!status->status.ok()) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
  return ret;
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
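
// Illustrative sketch (editor's addition): the intended call order for the
// ApiDefMap functions above. All TF_ApiDefMapPut calls must precede the first
// TF_ApiDefMapGet, which triggers UpdateDocs. `api_def_text` and `status` are
// hypothetical placeholders.
//
//   TF_Buffer* ops = TF_GetAllOpList();
//   TF_ApiDefMap* api_map = TF_NewApiDefMap(ops, status);
//   TF_ApiDefMapPut(api_map, api_def_text, strlen(api_def_text), status);
//   TF_Buffer* api_def = TF_ApiDefMapGet(api_map, "MatMul", 6, status);
//   if (api_def != nullptr) TF_DeleteBuffer(api_def);
//   TF_DeleteApiDefMap(api_map);
//   TF_DeleteBuffer(ops);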

TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
  tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(kernel_list, ret);
  if (!status->status.ok()) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
  return ret;
}

TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
  tensorflow::KernelList kernel_list =
      tensorflow::GetRegisteredKernelsForOp(name);
  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(kernel_list, ret);
  if (!status->status.ok()) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
  return ret;
}
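
// Illustrative sketch (editor's addition): both functions above return a
// TF_Buffer holding a serialized tensorflow.KernelList proto that the caller
// must free. `status` is a hypothetical placeholder.
//
//   TF_Buffer* kernels = TF_GetRegisteredKernelsForOp("MatMul", status);
//   if (TF_GetCode(status) == TF_OK) {
//     // Parse kernels->data / kernels->length as a KernelList proto.
//     TF_DeleteBuffer(kernels);
//   }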

void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
                   TF_Status* status) {
  using tensorflow::RecordMutation;
  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(&new_src.oper->node);

  if (ic->num_outputs() <= new_src.index) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Output index [", new_src.index,
        "] is greater than the number of total outputs [", ic->num_outputs(),
        "].");
    return;
  }
  tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);

  tensorflow::shape_inference::InferenceContext* ic_dst =
      graph->refiner.GetContext(&dst.oper->node);
  if (ic_dst->num_inputs() <= dst.index) {
    status->status = tensorflow::errors::OutOfRange(
        "Cannot update edge. Input index [", dst.index,
        "] is greater than the number of total inputs [", ic_dst->num_inputs(),
        "].");
    return;
  }
  if (!ic_dst->MergeInput(dst.index, shape)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
        " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
    return;
  }
  status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
                                           &dst.oper->node, dst.index);

  if (TF_GetCode(status) == TF_OK) {
    // This modification only updates the destination node for
    // the purposes of running this graph in a session. Thus, we don't
    // record the source node as being modified.
    RecordMutation(graph, *dst.oper, "updating input tensor");
  }
}
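
// Illustrative sketch (editor's addition): rewiring an existing input of
// `dst_op` to read from output 0 of `new_src_op` instead. The operation
// handles, `graph`, and `status` are hypothetical placeholders.
//
//   TF_Output new_src = {new_src_op, 0};
//   TF_Input dst = {dst_op, 0};
//   TF_UpdateEdge(graph, new_src, dst, status);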

// TF_Server functions ----------------------------------------------

#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
TF_Server::TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)
    : target(server->target()), server(std::move(server)) {}
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

TF_Server* TF_NewServer(const void* proto, size_t proto_len,
                        TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
  return nullptr;
#else
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, static_cast<int>(proto_len))) {
    status->status = InvalidArgument(
        "Could not parse provided bytes into a ServerDef protocol buffer");
    return nullptr;
  }

  std::unique_ptr<tensorflow::ServerInterface> out_server;
  status->status = tensorflow::NewServer(server_def, &out_server);
  if (!status->status.ok()) return nullptr;

  return new TF_Server(std::move(out_server));
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

void TF_ServerStart(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  status->status = server->server->Start();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

void TF_ServerStop(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  status->status = server->server->Stop();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

void TF_ServerJoin(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  status->status = server->server->Join();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

const char* TF_ServerTarget(TF_Server* server) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  return nullptr;
#else
  return server->target.c_str();
#endif
}

void TF_DeleteServer(TF_Server* server) {
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
  delete server;
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
}
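
// Illustrative sketch (editor's addition): the expected lifecycle of the
// in-process server API above. `server_def_bytes`, `server_def_len`, and
// `status` are hypothetical placeholders holding a serialized
// tensorflow.ServerDef proto.
//
//   TF_Server* server = TF_NewServer(server_def_bytes, server_def_len, status);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_ServerStart(server, status);
//     // Clients can connect to TF_ServerTarget(server).
//     TF_ServerJoin(server, status);  // Blocks until the server shuts down.
//   }
//   TF_DeleteServer(server);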

void TF_RegisterLogListener(void (*listener)(const char*)) {
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
  tensorflow::logging::RegisterListener(listener);
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
}

void TF_RegisterFilesystemPlugin(const char* plugin_filename,
                                 TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "FileSystem plugin functionality is not supported on mobile");
#else
  status->status = tensorflow::RegisterFilesystemPlugin(plugin_filename);
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
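
// Illustrative sketch (editor's addition): loading a modular filesystem plugin
// from a shared library. The path is a hypothetical placeholder.
//
//   TF_RegisterFilesystemPlugin("/path/to/libmy_filesystem_plugin.so", status);
//   if (TF_GetCode(status) != TF_OK) { /* handle the error */ }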

}  // end extern "C"