• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api.h"
17 
18 #include <algorithm>
19 #include <limits>
20 #include <memory>
21 #include <vector>
22 
23 #ifndef __ANDROID__
24 #include "tensorflow/cc/framework/gradients.h"
25 #include "tensorflow/cc/framework/ops.h"
26 #include "tensorflow/cc/framework/scope_internal.h"
27 #include "tensorflow/cc/ops/while_loop.h"
28 #include "tensorflow/cc/saved_model/loader.h"
29 #include "tensorflow/core/framework/op_gen_lib.h"
30 #endif
31 #include "tensorflow/c/c_api_internal.h"
32 #include "tensorflow/core/common_runtime/device_mgr.h"
33 #include "tensorflow/core/common_runtime/shape_refiner.h"
34 #include "tensorflow/core/framework/allocation_description.pb.h"
35 #include "tensorflow/core/framework/log_memory.h"
36 #include "tensorflow/core/framework/node_def_util.h"
37 #include "tensorflow/core/framework/op_kernel.h"
38 #include "tensorflow/core/framework/partial_tensor_shape.h"
39 #include "tensorflow/core/framework/tensor.h"
40 #include "tensorflow/core/framework/tensor_shape.h"
41 #include "tensorflow/core/framework/tensor_shape.pb.h"
42 #include "tensorflow/core/framework/types.h"
43 #include "tensorflow/core/framework/versions.pb.h"
44 #include "tensorflow/core/graph/graph.h"
45 #include "tensorflow/core/graph/graph_constructor.h"
46 #include "tensorflow/core/graph/node_builder.h"
47 #include "tensorflow/core/lib/core/coding.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/core/status.h"
50 #include "tensorflow/core/lib/core/stringpiece.h"
51 #include "tensorflow/core/lib/gtl/array_slice.h"
52 #include "tensorflow/core/lib/strings/strcat.h"
53 #include "tensorflow/core/platform/mem.h"
54 #include "tensorflow/core/platform/mutex.h"
55 #include "tensorflow/core/platform/protobuf.h"
56 #include "tensorflow/core/platform/thread_annotations.h"
57 #include "tensorflow/core/platform/types.h"
58 #include "tensorflow/core/public/session.h"
59 #include "tensorflow/core/public/version.h"
60 
61 // The implementation below is at the top level instead of the
62 // brain namespace because we are defining 'extern "C"' functions.
63 using tensorflow::AllocationDescription;
64 using tensorflow::DataType;
65 using tensorflow::Graph;
66 using tensorflow::GraphDef;
67 using tensorflow::mutex_lock;
68 using tensorflow::NameRangeMap;
69 using tensorflow::NameRangesForNode;
70 using tensorflow::NewSession;
71 using tensorflow::Node;
72 using tensorflow::NodeBuilder;
73 using tensorflow::NodeDef;
74 using tensorflow::OpDef;
75 using tensorflow::OpRegistry;
76 using tensorflow::PartialTensorShape;
77 using tensorflow::RunMetadata;
78 using tensorflow::RunOptions;
79 using tensorflow::Session;
80 using tensorflow::Status;
81 using tensorflow::string;
82 using tensorflow::Tensor;
83 using tensorflow::TensorBuffer;
84 using tensorflow::TensorId;
85 using tensorflow::TensorShape;
86 using tensorflow::TensorShapeProto;
87 using tensorflow::VersionDef;
88 using tensorflow::error::Code;
89 using tensorflow::errors::FailedPrecondition;
90 using tensorflow::errors::InvalidArgument;
91 using tensorflow::gtl::ArraySlice;
92 using tensorflow::strings::StrCat;
93 
94 extern "C" {
95 
96 // --------------------------------------------------------------------------
TF_Version()97 const char* TF_Version() { return TF_VERSION_STRING; }
98 
99 // --------------------------------------------------------------------------
TF_DataTypeSize(TF_DataType dt)100 size_t TF_DataTypeSize(TF_DataType dt) {
101   return static_cast<size_t>(
102       tensorflow::DataTypeSize(static_cast<DataType>(dt)));
103 }
104 
105 // --------------------------------------------------------------------------
106 
TF_NewStatus()107 TF_Status* TF_NewStatus() { return new TF_Status; }
108 
TF_DeleteStatus(TF_Status * s)109 void TF_DeleteStatus(TF_Status* s) { delete s; }
110 
TF_SetStatus(TF_Status * s,TF_Code code,const char * msg)111 void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
112   if (code == TF_OK) {
113     s->status = Status::OK();
114     return;
115   }
116   s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
117 }
118 
TF_GetCode(const TF_Status * s)119 TF_Code TF_GetCode(const TF_Status* s) {
120   return static_cast<TF_Code>(s->status.code());
121 }
122 
TF_Message(const TF_Status * s)123 const char* TF_Message(const TF_Status* s) {
124   return s->status.error_message().c_str();
125 }
126 
127 // --------------------------------------------------------------------------
128 
129 namespace {
130 class TF_ManagedBuffer : public TensorBuffer {
131  public:
132   void* data_;
133   size_t len_;
134   void (*deallocator_)(void* data, size_t len, void* arg);
135   void* deallocator_arg_;
136 
~TF_ManagedBuffer()137   ~TF_ManagedBuffer() override {
138     (*deallocator_)(data_, len_, deallocator_arg_);
139   }
140 
data() const141   void* data() const override { return data_; }
size() const142   size_t size() const override { return len_; }
root_buffer()143   TensorBuffer* root_buffer() override { return this; }
FillAllocationDescription(AllocationDescription * proto) const144   void FillAllocationDescription(AllocationDescription* proto) const override {
145     tensorflow::int64 rb = size();
146     proto->set_requested_bytes(rb);
147     proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
148   }
149 
150   // Prevents input forwarding from mutating this buffer.
OwnsMemory() const151   bool OwnsMemory() const override { return false; }
152 };
153 
allocate_tensor(const char * operation,size_t len)154 void* allocate_tensor(const char* operation, size_t len) {
155   void* data =
156       tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len);
157   if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
158     tensorflow::LogMemory::RecordRawAllocation(
159         operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID,
160         len, data, tensorflow::cpu_allocator());
161   }
162   return data;
163 }
164 
deallocate_buffer(void * data,size_t len,void * arg)165 void deallocate_buffer(void* data, size_t len, void* arg) {
166   if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
167     tensorflow::LogMemory::RecordRawDeallocation(
168         "TensorFlow C Api",
169         tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data,
170         tensorflow::cpu_allocator(), false);
171   }
172   tensorflow::cpu_allocator()->DeallocateRaw(data);
173 }
174 
175 }  // namespace
176 
~TF_Tensor()177 TF_Tensor::~TF_Tensor() { buffer->Unref(); }
178 
TF_AllocateTensor(TF_DataType dtype,const int64_t * dims,int num_dims,size_t len)179 TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
180                              int num_dims, size_t len) {
181   void* data = allocate_tensor("TF_AllocateTensor", len);
182   return TF_NewTensor(dtype, dims, num_dims, data, len, deallocate_buffer,
183                       nullptr);
184 }
185 
TF_NewTensor(TF_DataType dtype,const int64_t * dims,int num_dims,void * data,size_t len,void (* deallocator)(void * data,size_t len,void * arg),void * deallocator_arg)186 TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
187                         void* data, size_t len,
188                         void (*deallocator)(void* data, size_t len, void* arg),
189                         void* deallocator_arg) {
190   std::vector<tensorflow::int64> dimvec(num_dims);
191   for (int i = 0; i < num_dims; ++i) {
192     dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
193   }
194 
195   TF_ManagedBuffer* buf = new TF_ManagedBuffer;
196   buf->len_ = len;
197   if (dtype != TF_STRING && dtype != TF_RESOURCE &&
198       tensorflow::DataTypeCanUseMemcpy(static_cast<DataType>(dtype)) &&
199       reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) {
200     // TF_STRING and TF_RESOURCE tensors have a different representation in
201     // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste
202     // (any alignment requirements will be taken care of by TF_TensorToTensor
203     // and TF_TensorFromTensor).
204     //
205     // Other types have the same representation, so copy only if it is safe to
206     // do so.
207     buf->data_ = allocate_tensor("TF_NewTensor", len);
208     std::memcpy(buf->data_, data, len);
209     buf->deallocator_ = deallocate_buffer;
210     buf->deallocator_arg_ = nullptr;
211     // Free the original buffer.
212     deallocator(data, len, deallocator_arg);
213   } else {
214     buf->data_ = data;
215     buf->deallocator_ = deallocator;
216     buf->deallocator_arg_ = deallocator_arg;
217   }
218   TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf};
219   size_t elem_size = TF_DataTypeSize(dtype);
220   if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) {
221     delete ret;
222     return nullptr;
223   }
224   return ret;
225 }
226 
TF_TensorMaybeMove(TF_Tensor * tensor)227 TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
228   // It is safe to move the Tensor if and only if we own the unique reference to
229   // it. In that case, we might as well not delete and reallocate, but a future
230   // implementation might need to do so.
231   TensorBuffer* buf = tensor->buffer;
232   if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
233       buf->OwnsMemory()) {
234     return tensor;
235   }
236   return nullptr;
237 }
238 
TF_DeleteTensor(TF_Tensor * t)239 void TF_DeleteTensor(TF_Tensor* t) { delete t; }
240 
TF_TensorType(const TF_Tensor * t)241 TF_DataType TF_TensorType(const TF_Tensor* t) { return t->dtype; }
TF_NumDims(const TF_Tensor * t)242 int TF_NumDims(const TF_Tensor* t) { return t->shape.dims(); }
TF_Dim(const TF_Tensor * t,int dim_index)243 int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
244   return static_cast<int64_t>(t->shape.dim_size(dim_index));
245 }
TF_TensorByteSize(const TF_Tensor * t)246 size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); }
TF_TensorData(const TF_Tensor * t)247 void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); }
248 
249 // --------------------------------------------------------------------------
TF_StringEncode(const char * src,size_t src_len,char * dst,size_t dst_len,TF_Status * status)250 size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
251                        size_t dst_len, TF_Status* status) {
252   const size_t sz = TF_StringEncodedSize(src_len);
253   if (sz < src_len) {
254     status->status = InvalidArgument("src string is too large to encode");
255     return 0;
256   }
257   if (dst_len < sz) {
258     status->status =
259         InvalidArgument("dst_len (", dst_len, ") too small to encode a ",
260                         src_len, "-byte string");
261     return 0;
262   }
263   dst = tensorflow::core::EncodeVarint64(dst, src_len);
264   memcpy(dst, src, src_len);
265   return sz;
266 }
267 
TF_StringDecode_Impl(const char * src,size_t src_len,const char ** dst,size_t * dst_len)268 static Status TF_StringDecode_Impl(const char* src, size_t src_len,
269                                    const char** dst, size_t* dst_len) {
270   tensorflow::uint64 len64 = 0;
271   const char* p = tensorflow::core::GetVarint64Ptr(src, src + src_len, &len64);
272   if (p == nullptr) {
273     return InvalidArgument("invalid string encoding or truncated src buffer");
274   }
275   if (len64 > std::numeric_limits<size_t>::max()) {
276     return InvalidArgument("encoded string is ", len64,
277                            "-bytes, which is too large for this architecture");
278   }
279   *dst = p;
280   *dst_len = static_cast<size_t>(len64);
281   return Status::OK();
282 }
283 
TF_StringDecode(const char * src,size_t src_len,const char ** dst,size_t * dst_len,TF_Status * status)284 size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
285                        size_t* dst_len, TF_Status* status) {
286   status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len);
287   if (!status->status.ok()) return 0;
288   return static_cast<size_t>(*dst - src) + *dst_len;
289 }
290 
TF_StringEncodedSize(size_t len)291 size_t TF_StringEncodedSize(size_t len) {
292   return static_cast<size_t>(tensorflow::core::VarintLength(len)) + len;
293 }
294 
295 // --------------------------------------------------------------------------
TF_NewSessionOptions()296 TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
TF_DeleteSessionOptions(TF_SessionOptions * opt)297 void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
298 
TF_SetTarget(TF_SessionOptions * options,const char * target)299 void TF_SetTarget(TF_SessionOptions* options, const char* target) {
300   options->options.target = target;
301 }
302 
TF_SetConfig(TF_SessionOptions * options,const void * proto,size_t proto_len,TF_Status * status)303 void TF_SetConfig(TF_SessionOptions* options, const void* proto,
304                   size_t proto_len, TF_Status* status) {
305   if (!options->options.config.ParseFromArray(proto, proto_len)) {
306     status->status = InvalidArgument("Unparseable ConfigProto");
307   }
308 }
309 // --------------------------------------------------------------------------
TF_NewBuffer()310 TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; }
311 
TF_NewBufferFromString(const void * proto,size_t proto_len)312 TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
313   void* copy = tensorflow::port::Malloc(proto_len);
314   memcpy(copy, proto, proto_len);
315 
316   TF_Buffer* buf = new TF_Buffer;
317   buf->data = copy;
318   buf->length = proto_len;
319   buf->data_deallocator = [](void* data, size_t length) {
320     tensorflow::port::Free(data);
321   };
322   return buf;
323 }
324 
TF_DeleteBuffer(TF_Buffer * buffer)325 void TF_DeleteBuffer(TF_Buffer* buffer) {
326   if (buffer->data_deallocator != nullptr) {
327     (*buffer->data_deallocator)(const_cast<void*>(buffer->data),
328                                 buffer->length);
329   }
330   delete buffer;
331 }
332 
TF_GetBuffer(TF_Buffer * buffer)333 TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }
334 
335 // --------------------------------------------------------------------------
336 
TF_NewDeprecatedSession(const TF_SessionOptions * opt,TF_Status * status)337 TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
338                                               TF_Status* status) {
339   Session* session;
340   status->status = NewSession(opt->options, &session);
341   if (status->status.ok()) {
342     return new TF_DeprecatedSession({session});
343   } else {
344     DCHECK_EQ(nullptr, session);
345     return nullptr;
346   }
347 }
348 
TF_CloseDeprecatedSession(TF_DeprecatedSession * s,TF_Status * status)349 void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
350   status->status = s->session->Close();
351 }
352 
TF_DeleteDeprecatedSession(TF_DeprecatedSession * s,TF_Status * status)353 void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
354   status->status = Status::OK();
355   delete s->session;
356   delete s;
357 }
358 
TF_ExtendGraph(TF_DeprecatedSession * s,const void * proto,size_t proto_len,TF_Status * status)359 void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
360                     size_t proto_len, TF_Status* status) {
361   GraphDef g;
362   if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
363     status->status = InvalidArgument("Invalid GraphDef");
364     return;
365   }
366   status->status = s->session->Extend(g);
367 }
368 
DeleteArray(void * data,size_t size,void * arg)369 static void DeleteArray(void* data, size_t size, void* arg) {
370   DCHECK_EQ(data, arg);
371   delete[] reinterpret_cast<char*>(arg);
372 }
373 
374 }  // end extern "C"
375 
376 namespace tensorflow {
377 namespace {
378 
379 // Reset helper for converting character arrays to string vectors.
TF_Reset_Helper(const TF_SessionOptions * opt,const char ** containers,int ncontainers,TF_Status * status)380 void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers,
381                      int ncontainers, TF_Status* status) {
382   std::vector<string> container_names(ncontainers);
383   for (int i = 0; i < ncontainers; ++i) {
384     container_names[i] = containers[i];
385   }
386 
387   status->status = Reset(opt->options, container_names);
388 }
389 
390 // This traverses the specified nodes in topological order to verify there are
391 // no cycles. Starting with inputless nodes, it visits nodes whose inputs have
392 // all been visited, and counts the total number of visited nodes. If there is a
393 // cycle, nodes in the cycle will never be visited, and the visited count will
394 // be less than the total node count.
ValidateNoCycles(const Graph & g)395 Status ValidateNoCycles(const Graph& g) {
396   // TODO(nolivia): check this on a subset of the graph instead of all of it.
397   // A node is ready when all of its inputs have been visited.
398   std::vector<const Node*> ready;
399   std::vector<int> pending_count(g.num_node_ids(), 0);
400 
401   for (int i = 0; i < g.num_node_ids(); ++i) {
402     const Node* n = g.FindNodeId(i);
403     if (n == nullptr) continue;
404     pending_count[i] = n->in_edges().size();
405     if (n->IsMerge()) {
406       // While-loop cycles are legal cycles so we manually adjust the
407       // pending_count to make sure that the loop is visited.
408       for (const Edge* e : n->in_edges()) {
409         if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
410           pending_count[i]--;
411         }
412       }
413     }
414     if (pending_count[i] == 0) {
415       ready.push_back(n);
416     }
417   }
418 
419   int processed = 0;
420   while (!ready.empty()) {
421     const Node* node = ready.back();
422     ready.pop_back();
423     ++processed;
424 
425     for (const Edge* out : node->out_edges()) {
426       const int output_id = out->dst()->id();
427       pending_count[output_id]--;
428       if (pending_count[output_id] == 0) {
429         ready.push_back(out->dst());
430       }
431     }
432   }
433 
434   if (processed < g.num_nodes()) {
435     std::vector<string> nodes_in_cycle;
436     for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3;
437          ++i) {
438       if (pending_count[i] != 0) {
439         nodes_in_cycle.push_back(g.FindNodeId(i)->name());
440       }
441     }
442     return errors::InvalidArgument(
443         "Graph is invalid, contains a cycle with ", g.num_nodes() - processed,
444         " nodes, including: ", str_util::Join(nodes_in_cycle, ", "));
445   }
446   return Status::OK();
447 }
448 }  // namespace
449 }  // namespace tensorflow
450 
451 extern "C" {
452 
TF_Reset(const TF_SessionOptions * opt,const char ** containers,int ncontainers,TF_Status * status)453 void TF_Reset(const TF_SessionOptions* opt, const char** containers,
454               int ncontainers, TF_Status* status) {
455   tensorflow::TF_Reset_Helper(opt, containers, ncontainers, status);
456 }
457 
458 }  // end extern "C"
459 
460 namespace tensorflow {
461 
TF_TensorToTensor(const TF_Tensor * src,Tensor * dst)462 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
463   if (src->dtype == TF_RESOURCE) {
464     if (src->shape.dims() != 0) {
465       return InvalidArgument(
466           "Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
467           "shape ",
468           src->shape.DebugString());
469     }
470     *dst = Tensor(DT_RESOURCE, src->shape);
471     if (!dst->scalar<ResourceHandle>()().ParseFromString(
472             string(static_cast<const char*>(TF_TensorData(src)),
473                    TF_TensorByteSize(src)))) {
474       return InvalidArgument(
475           "Malformed TF_RESOUCE tensor: unable to parse resource handle");
476     }
477     return Status::OK();
478   }
479   if (src->dtype != TF_STRING) {
480     *dst = TensorCApi::MakeTensor(src->dtype, src->shape, src->buffer);
481     return Status::OK();
482   }
483   // TF_STRING tensors require copying since Tensor class expects a sequence of
484   // string objects.
485   const tensorflow::int64 num_elements = src->shape.num_elements();
486   const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
487   const size_t src_size = TF_TensorByteSize(src);
488   if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
489       num_elements) {
490     return InvalidArgument(
491         "Malformed TF_STRING tensor; too short to hold number of elements");
492   }
493   const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
494   const char* limit = input + src_size;
495 
496   *dst = Tensor(static_cast<DataType>(src->dtype), src->shape);
497   auto dstarray = dst->flat<string>();
498   for (tensorflow::int64 i = 0; i < num_elements; ++i) {
499     tensorflow::uint64 offset =
500         reinterpret_cast<const tensorflow::uint64*>(input)[i];
501     if (static_cast<ptrdiff_t>(offset) >= (limit - data_start)) {
502       return InvalidArgument("Malformed TF_STRING tensor; element ", i,
503                              " out of range");
504     }
505     size_t len;
506     const char* p;
507     const char* srcp = data_start + offset;
508     Status status = TF_StringDecode_Impl(srcp, limit - srcp, &p, &len);
509     if (!status.ok()) return status;
510     dstarray(i).assign(p, len);
511   }
512   return Status::OK();
513 }
514 
515 // Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
516 // result in a zero-sized tensor.
EmptyTensor(TF_DataType dtype,const TensorShape & shape)517 static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) {
518   static char empty;
519   tensorflow::int64 nelems = 1;
520   std::vector<tensorflow::int64> dims;
521   for (int i = 0; i < shape.dims(); ++i) {
522     dims.push_back(shape.dim_size(i));
523     nelems *= shape.dim_size(i);
524   }
525   CHECK_EQ(nelems, 0);
526   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
527                 "64-bit int types should match in size");
528   return TF_NewTensor(dtype, reinterpret_cast<const int64_t*>(dims.data()),
529                       shape.dims(), reinterpret_cast<void*>(&empty), 0,
530                       [](void*, size_t, void*) {}, nullptr);
531 }
532 
533 // Non-static for testing.
TF_TensorFromTensor(const tensorflow::Tensor & src,TF_Status * status)534 TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
535                                TF_Status* status) {
536   if (!src.IsInitialized()) {
537     status->status = FailedPrecondition(
538         "attempt to use a tensor with an uninitialized value");
539     return nullptr;
540   }
541   if (src.NumElements() == 0) {
542     return EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
543   }
544   if (src.dtype() == DT_RESOURCE) {
545     if (src.shape().dims() != 0) {
546       status->status = InvalidArgument(
547           "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
548           src.shape().DebugString(),
549           "). Please file a bug at "
550           "https://github.com/tensorflow/tensorflow/issues/new, "
551           "ideally with a "
552           "short code snippet that reproduces this error.");
553       return nullptr;
554     }
555     const string str = src.scalar<ResourceHandle>()().SerializeAsString();
556     TF_Tensor* t = TF_AllocateTensor(TF_RESOURCE, {}, 0, str.size());
557     std::memcpy(TF_TensorData(t), str.c_str(), str.size());
558     return t;
559   }
560   if (src.dtype() != DT_STRING) {
561     TensorBuffer* buf = TensorCApi::Buffer(src);
562     buf->Ref();
563     return new TF_Tensor{static_cast<TF_DataType>(src.dtype()), src.shape(),
564                          buf};
565   }
566   // DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
567   // encoded sequence of strings.
568 
569   // Compute bytes needed for encoding.
570   size_t size = 0;
571   const auto& srcarray = src.flat<string>();
572   for (int i = 0; i < srcarray.size(); ++i) {
573     const string& s = srcarray(i);
574     // uint64 starting_offset, TF_StringEncode-d string.
575     size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size());
576   }
577 
578   // Encode all strings.
579   char* base = new char[size];
580   char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size();
581   char* dst = data_start;  // Where next string is encoded.
582   size_t dst_len = size - static_cast<size_t>(data_start - base);
583   tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
584   for (int i = 0; i < srcarray.size(); ++i) {
585     *offsets = (dst - data_start);
586     offsets++;
587     const string& s = srcarray(i);
588     size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
589     if (!status->status.ok()) {
590       status->status = InvalidArgument(
591           "invalid string tensor encoding (string #", i, " of ",
592           srcarray.size(), "): ", status->status.error_message());
593       delete[] base;
594       return nullptr;
595     }
596     dst += consumed;
597     dst_len -= consumed;
598   }
599   if (dst != base + size) {
600     status->status = InvalidArgument(
601         "invalid string tensor encoding (decoded ", (dst - base),
602         " bytes, but the tensor is encoded in ", size, " bytes");
603     delete[] base;
604     return nullptr;
605   }
606 
607   auto dims = src.shape().dim_sizes();
608   std::vector<tensorflow::int64> dimvec(dims.size());
609   for (size_t i = 0; i < dims.size(); ++i) {
610     dimvec[i] = dims[i];
611   }
612   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
613                 "64-bit int types should match in size");
614   return TF_NewTensor(TF_STRING,
615                       reinterpret_cast<const int64_t*>(dimvec.data()),
616                       dimvec.size(), base, size, DeleteArray, base);
617 }
618 
MessageToBuffer(const tensorflow::protobuf::Message & in,TF_Buffer * out)619 Status MessageToBuffer(const tensorflow::protobuf::Message& in,
620                        TF_Buffer* out) {
621   if (out->data != nullptr) {
622     return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
623   }
624   const size_t proto_size = in.ByteSizeLong();
625   void* buf = tensorflow::port::Malloc(proto_size);
626   if (buf == nullptr) {
627     return tensorflow::errors::ResourceExhausted(
628         "Failed to allocate memory to serialize message of type '",
629         in.GetTypeName(), "' and size ", proto_size);
630   }
631   in.SerializeToArray(buf, proto_size);
632   out->data = buf;
633   out->length = proto_size;
634   out->data_deallocator = [](void* data, size_t length) {
635     tensorflow::port::Free(data);
636   };
637   return Status::OK();
638 }
639 
RecordMutation(TF_Graph * graph,const TF_Operation & op,const char * mutation_type)640 void RecordMutation(TF_Graph* graph, const TF_Operation& op,
641                     const char* mutation_type)
642     EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
643   // If any session has already run this node_id, mark this session as
644   // unrunnable.
645   for (auto it : graph->sessions) {
646     if (it.first->last_num_graph_nodes > op.node.id()) {
647       it.second = FailedPrecondition(
648           "Operation '", op.node.DebugString(), "' was changed by ",
649           mutation_type,
650           " after it was run by a session. Nodes can be mutated "
651           "only before they are executed by a session. Either don't modify "
652           "nodes after running them or create a new session.");
653     }
654   }
655 }
656 
657 namespace {
658 
659 // Helper method that creates a shape handle for a shape described by dims.
ShapeHandleFromDims(tensorflow::shape_inference::InferenceContext * ic,int num_dims,const int64_t * dims)660 tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
661     tensorflow::shape_inference::InferenceContext* ic, int num_dims,
662     const int64_t* dims) {
663   if (num_dims != -1) {
664     std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
665     dim_vec.reserve(num_dims);
666     for (int i = 0; i < num_dims; ++i) {
667       dim_vec.push_back(ic->MakeDim(dims[i]));
668     }
669     return ic->MakeShape(dim_vec);
670   } else {
671     return ic->UnknownShape();
672   }
673 }
674 
675 }  // namespace
676 
TF_GraphSetOutputHandleShapesAndTypes(TF_Graph * graph,TF_Output output,int num_shapes_and_types,const int64_t ** shapes,const int * ranks,const TF_DataType * types,TF_Status * status)677 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
678                                            int num_shapes_and_types,
679                                            const int64_t** shapes,
680                                            const int* ranks,
681                                            const TF_DataType* types,
682                                            TF_Status* status) {
683   Node* node = &output.oper->node;
684 
685   mutex_lock l(graph->mu);
686   tensorflow::shape_inference::InferenceContext* ic =
687       graph->refiner.GetContext(node);
688   if (ic == nullptr) {
689     status->status =
690         InvalidArgument("Node ", node->name(), " was not found in the graph");
691     return;
692   }
693 
694   auto shape_and_type_vec =
695       std::vector<tensorflow::shape_inference::ShapeAndType>(
696           num_shapes_and_types);
697   for (int i = 0; i < num_shapes_and_types; ++i) {
698     tensorflow::shape_inference::ShapeHandle shape_handle =
699         ShapeHandleFromDims(ic, ranks[i], shapes[i]);
700     shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
701         shape_handle, static_cast<DataType>(types[i]));
702   }
703 
704   ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
705 }
706 
707 // Helpers for loading a TensorFlow plugin (a .so file).
708 Status LoadLibrary(const char* library_filename, void** result,
709                    const void** buf, size_t* len);
710 
711 }  // namespace tensorflow
712 
TF_Run_Setup(int noutputs,TF_Tensor ** c_outputs,TF_Status * status)713 static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
714                          TF_Status* status) {
715   status->status = Status::OK();
716   for (int i = 0; i < noutputs; ++i) {
717     c_outputs[i] = nullptr;
718   }
719 }
720 
TF_Run_Inputs(TF_Tensor * const * c_inputs,std::vector<std::pair<string,Tensor>> * input_pairs,TF_Status * status)721 static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
722                           std::vector<std::pair<string, Tensor>>* input_pairs,
723                           TF_Status* status) {
724   const int ninputs = input_pairs->size();
725   for (int i = 0; i < ninputs; ++i) {
726     status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
727     if (!status->status.ok()) return false;
728   }
729   return true;
730 }
731 
TF_Run_Helper(Session * session,const char * handle,const TF_Buffer * run_options,const std::vector<std::pair<string,Tensor>> & input_pairs,const std::vector<string> & output_tensor_names,TF_Tensor ** c_outputs,const std::vector<string> & target_oper_names,TF_Buffer * run_metadata,TF_Status * status)732 static void TF_Run_Helper(
733     Session* session, const char* handle, const TF_Buffer* run_options,
734     // Input tensors
735     const std::vector<std::pair<string, Tensor>>& input_pairs,
736     // Output tensors
737     const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
738     // Target nodes
739     const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
740     TF_Status* status) {
741   const int noutputs = output_tensor_names.size();
742   std::vector<Tensor> outputs(noutputs);
743   Status result;
744 
745   if (handle == nullptr) {
746     RunOptions run_options_proto;
747     if (run_options != nullptr && !run_options_proto.ParseFromArray(
748                                       run_options->data, run_options->length)) {
749       status->status = InvalidArgument("Unparseable RunOptions proto");
750       return;
751     }
752     if (run_metadata != nullptr && run_metadata->data != nullptr) {
753       status->status =
754           InvalidArgument("Passing non-empty run_metadata is invalid.");
755       return;
756     }
757 
758     RunMetadata run_metadata_proto;
759     result = session->Run(run_options_proto, input_pairs, output_tensor_names,
760                           target_oper_names, &outputs, &run_metadata_proto);
761 
762     // Serialize back to upstream client, who now owns the new buffer
763     if (run_metadata != nullptr) {
764       status->status = MessageToBuffer(run_metadata_proto, run_metadata);
765       if (!status->status.ok()) return;
766     }
767   } else {
768     // NOTE(zongheng): PRun does not support RunOptions yet.
769     result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
770   }
771   if (!result.ok()) {
772     status->status = result;
773     return;
774   }
775 
776   // Store results in c_outputs[]
777   for (int i = 0; i < noutputs; ++i) {
778     const Tensor& src = outputs[i];
779     if (!src.IsInitialized() || src.NumElements() == 0) {
780       c_outputs[i] =
781           EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
782       continue;
783     }
784     c_outputs[i] = TF_TensorFromTensor(src, status);
785     if (!status->status.ok()) return;
786   }
787 }
788 
789 extern "C" {
790 
TF_Run(TF_DeprecatedSession * s,const TF_Buffer * run_options,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Buffer * run_metadata,TF_Status * status)791 void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
792             // Input tensors
793             const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
794             // Output tensors
795             const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
796             // Target nodes
797             const char** c_target_oper_names, int ntargets,
798             TF_Buffer* run_metadata, TF_Status* status) {
799   TF_Run_Setup(noutputs, c_outputs, status);
800   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
801   if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
802   for (int i = 0; i < ninputs; ++i) {
803     input_pairs[i].first = c_input_names[i];
804   }
805   std::vector<string> output_names(noutputs);
806   for (int i = 0; i < noutputs; ++i) {
807     output_names[i] = c_output_names[i];
808   }
809   std::vector<string> target_oper_names(ntargets);
810   for (int i = 0; i < ntargets; ++i) {
811     target_oper_names[i] = c_target_oper_names[i];
812   }
813   TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
814                 c_outputs, target_oper_names, run_metadata, status);
815 }
816 
TF_PRunSetup(TF_DeprecatedSession * s,const char ** c_input_names,int ninputs,const char ** c_output_names,int noutputs,const char ** c_target_oper_names,int ntargets,const char ** handle,TF_Status * status)817 void TF_PRunSetup(TF_DeprecatedSession* s,
818                   // Input names
819                   const char** c_input_names, int ninputs,
820                   // Output names
821                   const char** c_output_names, int noutputs,
822                   // Target nodes
823                   const char** c_target_oper_names, int ntargets,
824                   const char** handle, TF_Status* status) {
825   *handle = nullptr;
826 
827   std::vector<string> input_names(ninputs);
828   std::vector<string> output_names(noutputs);
829   std::vector<string> target_oper_names(ntargets);
830   for (int i = 0; i < ninputs; ++i) {
831     input_names[i] = c_input_names[i];
832   }
833   for (int i = 0; i < noutputs; ++i) {
834     output_names[i] = c_output_names[i];
835   }
836   for (int i = 0; i < ntargets; ++i) {
837     target_oper_names[i] = c_target_oper_names[i];
838   }
839   string new_handle;
840   status->status = s->session->PRunSetup(input_names, output_names,
841                                          target_oper_names, &new_handle);
842   if (status->status.ok()) {
843     char* buf = new char[new_handle.size() + 1];
844     memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
845     *handle = buf;
846   }
847 }
848 
TF_PRun(TF_DeprecatedSession * s,const char * handle,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Status * status)849 void TF_PRun(TF_DeprecatedSession* s, const char* handle,
850              // Input tensors
851              const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
852              // Output tensors
853              const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
854              // Target nodes
855              const char** c_target_oper_names, int ntargets,
856              TF_Status* status) {
857   TF_Run_Setup(noutputs, c_outputs, status);
858   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
859   if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
860   for (int i = 0; i < ninputs; ++i) {
861     input_pairs[i].first = c_input_names[i];
862   }
863 
864   std::vector<string> output_names(noutputs);
865   for (int i = 0; i < noutputs; ++i) {
866     output_names[i] = c_output_names[i];
867   }
868   std::vector<string> target_oper_names(ntargets);
869   for (int i = 0; i < ntargets; ++i) {
870     target_oper_names[i] = c_target_oper_names[i];
871   }
872   TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
873                 c_outputs, target_oper_names, nullptr, status);
874 }
875 
TF_LoadLibrary(const char * library_filename,TF_Status * status)876 TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
877   TF_Library* lib_handle = new TF_Library;
878   status->status = tensorflow::LoadLibrary(
879       library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
880       &lib_handle->op_list.length);
881   if (!status->status.ok()) {
882     delete lib_handle;
883     return nullptr;
884   }
885   return lib_handle;
886 }
887 
TF_GetOpList(TF_Library * lib_handle)888 TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
889 
TF_DeleteLibraryHandle(TF_Library * lib_handle)890 void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
891   tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
892   delete lib_handle;
893 }
894 
TF_GetAllOpList()895 TF_Buffer* TF_GetAllOpList() {
896   std::vector<tensorflow::OpDef> op_defs;
897   tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
898   tensorflow::OpList op_list;
899   for (const auto& op : op_defs) {
900     *(op_list.add_op()) = op;
901   }
902   TF_Buffer* ret = TF_NewBuffer();
903   TF_CHECK_OK(MessageToBuffer(op_list, ret));
904   return ret;
905 }
906 
907 // --------------------------------------------------------------------------
908 // ListDevices & SessionListDevices API
909 
TF_DeleteDeviceList(TF_DeviceList * s)910 void TF_DeleteDeviceList(TF_DeviceList* s) { delete s; }
911 
TF_SessionListDevices(TF_Session * session,TF_Status * status)912 TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
913   TF_DeviceList* response = new TF_DeviceList;
914   status->status = session->session->ListDevices(&response->response);
915   return response;
916 }
917 
TF_DeprecatedSessionListDevices(TF_DeprecatedSession * session,TF_Status * status)918 TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
919                                                TF_Status* status) {
920   TF_DeviceList* response = new TF_DeviceList;
921   status->status = session->session->ListDevices(&response->response);
922   return response;
923 }
924 
TF_DeviceListCount(const TF_DeviceList * list)925 int TF_DeviceListCount(const TF_DeviceList* list) {
926   return list->response.size();
927 }
928 
929 #define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
930   return_type method_name(const TF_DeviceList* list, const int index,     \
931                           TF_Status* status) {                            \
932     if (list == nullptr) {                                                \
933       status->status = InvalidArgument("list is null!");                  \
934       return err_val;                                                     \
935     }                                                                     \
936     if (index < 0 || index >= list->response.size()) {                    \
937       status->status = InvalidArgument("index out of bounds");            \
938       return err_val;                                                     \
939     }                                                                     \
940     status->status = Status::OK();                                        \
941     return list->response[index].accessor;                                \
942   }
943 
944 TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
945 TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
946                      nullptr);
947 TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
948 
949 #undef TF_DEVICELIST_METHOD
950 
951 }  // end extern "C"
952 
953 // --------------------------------------------------------------------------
954 // New Graph and Session API
955 
956 // Helper functions -----------------------------------------------------------
957 
958 namespace {
959 
ToOperation(Node * node)960 TF_Operation* ToOperation(Node* node) {
961   return static_cast<TF_Operation*>(static_cast<void*>(node));
962 }
963 
OutputName(const TF_Output & output)964 string OutputName(const TF_Output& output) {
965   return StrCat(output.oper->node.name(), ":", output.index);
966 }
967 
GetAttrValue(TF_Operation * oper,const char * attr_name,TF_Status * status)968 const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
969                                           const char* attr_name,
970                                           TF_Status* status) {
971   const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
972   if (attr == nullptr) {
973     status->status = InvalidArgument("Operation '", oper->node.name(),
974                                      "' has no attr named '", attr_name, "'.");
975   }
976   return attr;
977 }
978 
ToTensorId(const TF_Output & output)979 TensorId ToTensorId(const TF_Output& output) {
980   return TensorId(output.oper->node.name(), output.index);
981 }
982 
983 #ifndef __ANDROID__
OutputsFromTFOutputs(TF_Output * tf_outputs,int n)984 std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
985                                                      int n) {
986   std::vector<tensorflow::Output> outputs(n);
987   for (int i = 0; i < n; ++i) {
988     outputs[i] =
989         tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
990   }
991   return outputs;
992 }
993 
TFOutputsFromOutputs(const std::vector<tensorflow::Output> & outputs,TF_Output * tf_outputs)994 void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
995                           TF_Output* tf_outputs) {
996   for (int i = 0; i < outputs.size(); i++) {
997     tf_outputs[i].oper = ToOperation(outputs[i].node());
998     tf_outputs[i].index = outputs[i].index();
999   }
1000 }
1001 #endif  // __ANDROID__
1002 
1003 }  // namespace
1004 
1005 // Shape functions -----------------------------------------------------------
1006 
TF_GraphSetTensorShape(TF_Graph * graph,TF_Output output,const int64_t * dims,const int num_dims,TF_Status * status)1007 void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
1008                             const int64_t* dims, const int num_dims,
1009                             TF_Status* status) {
1010   Node* node = &output.oper->node;
1011 
1012   mutex_lock l(graph->mu);
1013   tensorflow::shape_inference::InferenceContext* ic =
1014       graph->refiner.GetContext(node);
1015   if (ic == nullptr) {
1016     status->status =
1017         InvalidArgument("Node ", node->name(), " was not found in the graph");
1018     return;
1019   }
1020   tensorflow::shape_inference::ShapeHandle new_shape =
1021       tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
1022   status->status = graph->refiner.SetShape(node, output.index, new_shape);
1023 }
1024 
TF_GraphGetTensorNumDims(TF_Graph * graph,TF_Output output,TF_Status * status)1025 int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
1026                              TF_Status* status) {
1027   Node* node = &output.oper->node;
1028 
1029   mutex_lock l(graph->mu);
1030   tensorflow::shape_inference::InferenceContext* ic =
1031       graph->refiner.GetContext(node);
1032   if (ic == nullptr) {
1033     status->status =
1034         InvalidArgument("Node ", node->name(), " was not found in the graph");
1035     return -1;
1036   }
1037 
1038   tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
1039 
1040   // Unknown rank means the number of dimensions is -1.
1041   if (!ic->RankKnown(shape)) {
1042     return -1;
1043   }
1044 
1045   return ic->Rank(shape);
1046 }
1047 
TF_GraphGetTensorShape(TF_Graph * graph,TF_Output output,int64_t * dims,int num_dims,TF_Status * status)1048 void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
1049                             int num_dims, TF_Status* status) {
1050   Node* node = &output.oper->node;
1051 
1052   mutex_lock l(graph->mu);
1053   tensorflow::shape_inference::InferenceContext* ic =
1054       graph->refiner.GetContext(node);
1055   if (ic == nullptr) {
1056     status->status =
1057         InvalidArgument("Node ", node->name(), " was not found in the graph");
1058     return;
1059   }
1060 
1061   tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
1062 
1063   int rank = -1;
1064   if (ic->RankKnown(shape)) {
1065     rank = ic->Rank(shape);
1066   }
1067 
1068   if (num_dims != rank) {
1069     status->status = InvalidArgument("Expected rank is ", num_dims,
1070                                      " but actual rank is ", rank);
1071     return;
1072   }
1073 
1074   if (num_dims == 0) {
1075     // Output shape is a scalar.
1076     return;
1077   }
1078 
1079   // Rank is greater than 0, so fill in the values, if known, and
1080   // -1 for unknown values.
1081   for (int i = 0; i < num_dims; ++i) {
1082     auto dim = ic->Dim(shape, i);
1083     tensorflow::int64 value = -1;
1084     if (ic->ValueKnown(dim)) {
1085       value = ic->Value(dim);
1086     }
1087     dims[i] = value;
1088   }
1089 }
1090 
1091 // TF_OperationDescription functions ------------------------------------------
1092 
1093 extern "C" {
1094 
TF_NewOperationLocked(TF_Graph * graph,const char * op_type,const char * oper_name)1095 static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
1096                                                       const char* op_type,
1097                                                       const char* oper_name)
1098     EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
1099   return new TF_OperationDescription(graph, op_type, oper_name);
1100 }
1101 
TF_NewOperation(TF_Graph * graph,const char * op_type,const char * oper_name)1102 TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
1103                                          const char* oper_name) {
1104   mutex_lock l(graph->mu);
1105   return TF_NewOperationLocked(graph, op_type, oper_name);
1106 }
1107 
TF_SetDevice(TF_OperationDescription * desc,const char * device)1108 void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
1109   desc->node_builder.Device(device);
1110 }
1111 
TF_AddInput(TF_OperationDescription * desc,TF_Output input)1112 void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
1113   desc->node_builder.Input(&input.oper->node, input.index);
1114 }
1115 
TF_AddInputList(TF_OperationDescription * desc,const TF_Output * inputs,int num_inputs)1116 void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
1117                      int num_inputs) {
1118   std::vector<NodeBuilder::NodeOut> input_list;
1119   input_list.reserve(num_inputs);
1120   for (int i = 0; i < num_inputs; ++i) {
1121     input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
1122   }
1123   desc->node_builder.Input(input_list);
1124 }
1125 
TF_AddControlInput(TF_OperationDescription * desc,TF_Operation * input)1126 void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
1127   desc->node_builder.ControlInput(&input->node);
1128 }
1129 
TF_ColocateWith(TF_OperationDescription * desc,TF_Operation * op)1130 void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
1131   desc->colocation_constraints.emplace(
1132       StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
1133 }
1134 
TF_SetAttrString(TF_OperationDescription * desc,const char * attr_name,const void * value,size_t length)1135 void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
1136                       const void* value, size_t length) {
1137   tensorflow::StringPiece s(static_cast<const char*>(value), length);
1138   desc->node_builder.Attr(attr_name, s);
1139 }
1140 
TF_SetAttrStringList(TF_OperationDescription * desc,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)1141 void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
1142                           const void* const* values, const size_t* lengths,
1143                           int num_values) {
1144   if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1145     desc->colocation_constraints.clear();
1146     for (int i = 0; i < num_values; ++i) {
1147       desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
1148                                            lengths[i]);
1149     }
1150   } else {
1151     std::vector<tensorflow::StringPiece> v;
1152     v.reserve(num_values);
1153     for (int i = 0; i < num_values; ++i) {
1154       v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
1155     }
1156     desc->node_builder.Attr(attr_name, v);
1157   }
1158 }
1159 
TF_SetAttrInt(TF_OperationDescription * desc,const char * attr_name,int64_t value)1160 void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
1161                    int64_t value) {
1162   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1163                 "64-bit int types should match in size");
1164   desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
1165 }
1166 
TF_SetAttrIntList(TF_OperationDescription * desc,const char * attr_name,const int64_t * values,int num_values)1167 void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
1168                        const int64_t* values, int num_values) {
1169   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1170                 "64-bit int types should match in size");
1171   desc->node_builder.Attr(
1172       attr_name,
1173       ArraySlice<const tensorflow::int64>(
1174           reinterpret_cast<const tensorflow::int64*>(values), num_values));
1175 }
1176 
TF_SetAttrFloat(TF_OperationDescription * desc,const char * attr_name,float value)1177 void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
1178                      float value) {
1179   desc->node_builder.Attr(attr_name, value);
1180 }
1181 
TF_SetAttrFloatList(TF_OperationDescription * desc,const char * attr_name,const float * values,int num_values)1182 void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
1183                          const float* values, int num_values) {
1184   desc->node_builder.Attr(attr_name,
1185                           ArraySlice<const float>(values, num_values));
1186 }
1187 
TF_SetAttrBool(TF_OperationDescription * desc,const char * attr_name,unsigned char value)1188 void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
1189                     unsigned char value) {
1190   desc->node_builder.Attr(attr_name, static_cast<bool>(value));
1191 }
1192 
TF_SetAttrBoolList(TF_OperationDescription * desc,const char * attr_name,const unsigned char * values,int num_values)1193 void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
1194                         const unsigned char* values, int num_values) {
1195   std::unique_ptr<bool[]> b(new bool[num_values]);
1196   for (int i = 0; i < num_values; ++i) {
1197     b[i] = values[i];
1198   }
1199   desc->node_builder.Attr(attr_name,
1200                           ArraySlice<const bool>(b.get(), num_values));
1201 }
1202 
TF_SetAttrType(TF_OperationDescription * desc,const char * attr_name,TF_DataType value)1203 void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
1204                     TF_DataType value) {
1205   desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
1206 }
1207 
TF_SetAttrTypeList(TF_OperationDescription * desc,const char * attr_name,const TF_DataType * values,int num_values)1208 void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
1209                         const TF_DataType* values, int num_values) {
1210   desc->node_builder.Attr(
1211       attr_name, ArraySlice<const DataType>(
1212                      reinterpret_cast<const DataType*>(values), num_values));
1213 }
1214 
TF_SetAttrFuncName(TF_OperationDescription * desc,const char * attr_name,const char * value,size_t length)1215 void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
1216                         const char* value, size_t length) {
1217   tensorflow::NameAttrList func_name;
1218   func_name.set_name(std::string(value, value + length));
1219   desc->node_builder.Attr(attr_name, func_name);
1220 }
1221 
TF_SetAttrShape(TF_OperationDescription * desc,const char * attr_name,const int64_t * dims,int num_dims)1222 void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
1223                      const int64_t* dims, int num_dims) {
1224   PartialTensorShape shape;
1225   if (num_dims >= 0) {
1226     static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1227                   "64-bit int types should match in size");
1228     shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
1229         reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
1230   }
1231   desc->node_builder.Attr(attr_name, shape);
1232 }
1233 
TF_SetAttrShapeList(TF_OperationDescription * desc,const char * attr_name,const int64_t * const * dims,const int * num_dims,int num_shapes)1234 void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
1235                          const int64_t* const* dims, const int* num_dims,
1236                          int num_shapes) {
1237   std::vector<PartialTensorShape> shapes;
1238   shapes.reserve(num_shapes);
1239   for (int i = 0; i < num_shapes; ++i) {
1240     if (num_dims[i] < 0) {
1241       shapes.emplace_back();
1242     } else {
1243       static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1244                     "64-bit int types should match in size");
1245       shapes.emplace_back(ArraySlice<tensorflow::int64>(
1246           reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
1247     }
1248   }
1249   desc->node_builder.Attr(attr_name, shapes);
1250 }
1251 
TF_SetAttrTensorShapeProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)1252 void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
1253                                 const char* attr_name, const void* proto,
1254                                 size_t proto_len, TF_Status* status) {
1255   // shape.ParseFromArray takes an int as length, this function takes size_t,
1256   // make sure there is no information loss.
1257   if (proto_len > std::numeric_limits<int>::max()) {
1258     status->status = InvalidArgument(
1259         "proto_len (", proto_len,
1260         " bytes) is too large to be parsed by the protocol buffer library");
1261     return;
1262   }
1263   TensorShapeProto shape;
1264   if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
1265     desc->node_builder.Attr(attr_name, shape);
1266     status->status = Status::OK();
1267   } else {
1268     status->status = InvalidArgument("Unparseable TensorShapeProto");
1269   }
1270 }
1271 
TF_SetAttrTensorShapeProtoList(TF_OperationDescription * desc,const char * attr_name,const void * const * protos,const size_t * proto_lens,int num_shapes,TF_Status * status)1272 void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
1273                                     const char* attr_name,
1274                                     const void* const* protos,
1275                                     const size_t* proto_lens, int num_shapes,
1276                                     TF_Status* status) {
1277   std::vector<TensorShapeProto> shapes;
1278   shapes.resize(num_shapes);
1279   for (int i = 0; i < num_shapes; ++i) {
1280     if (proto_lens[i] > std::numeric_limits<int>::max()) {
1281       status->status = InvalidArgument(
1282           "length of element ", i, " in the list (", proto_lens[i],
1283           " bytes) is too large to be parsed by the protocol buffer library");
1284       return;
1285     }
1286     if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) {
1287       status->status =
1288           InvalidArgument("Unparseable TensorShapeProto at index ", i);
1289       return;
1290     }
1291   }
1292   desc->node_builder.Attr(attr_name, shapes);
1293   status->status = Status::OK();
1294 }
1295 
TF_SetAttrTensor(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * value,TF_Status * status)1296 void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
1297                       TF_Tensor* value, TF_Status* status) {
1298   Tensor t;
1299   status->status = TF_TensorToTensor(value, &t);
1300   if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1301 }
1302 
TF_SetAttrTensorList(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * const * values,int num_values,TF_Status * status)1303 void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
1304                           TF_Tensor* const* values, int num_values,
1305                           TF_Status* status) {
1306   status->status = Status::OK();
1307   std::vector<Tensor> t;
1308   t.reserve(num_values);
1309 
1310   for (int i = 0; i < num_values && status->status.ok(); ++i) {
1311     Tensor v;
1312     status->status = TF_TensorToTensor(values[i], &v);
1313     t.emplace_back(v);
1314   }
1315 
1316   if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1317 }
1318 
TF_SetAttrValueProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)1319 void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
1320                           const void* proto, size_t proto_len,
1321                           TF_Status* status) {
1322   tensorflow::AttrValue attr_value;
1323   if (!attr_value.ParseFromArray(proto, proto_len)) {
1324     status->status = InvalidArgument("Unparseable AttrValue proto");
1325     return;
1326   }
1327 
1328   if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1329     if (attr_value.value_case() != tensorflow::AttrValue::kList &&
1330         attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
1331       status->status =
1332           InvalidArgument("Expected \"list\" field for \"",
1333                           tensorflow::kColocationAttrName, "\" attribute");
1334       return;
1335     }
1336     desc->colocation_constraints.clear();
1337     for (const string& location : attr_value.list().s()) {
1338       desc->colocation_constraints.insert(location);
1339     }
1340   } else {
1341     desc->node_builder.Attr(attr_name, attr_value);
1342   }
1343 
1344   status->status = Status::OK();
1345 }
1346 
TF_FinishOperationLocked(TF_OperationDescription * desc,TF_Status * status)1347 static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
1348                                               TF_Status* status)
1349     EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
1350   Node* ret = nullptr;
1351 
1352   if (desc->graph->name_map.count(desc->node_builder.node_name())) {
1353     status->status = InvalidArgument("Duplicate node name in graph: '",
1354                                      desc->node_builder.node_name(), "'");
1355   } else {
1356     if (!desc->colocation_constraints.empty()) {
1357       desc->node_builder.Attr(
1358           tensorflow::kColocationAttrName,
1359           std::vector<string>(desc->colocation_constraints.begin(),
1360                               desc->colocation_constraints.end()));
1361     }
1362     status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);
1363 
1364     if (status->status.ok()) {
1365       // Run shape inference function for newly added node.
1366       status->status = desc->graph->refiner.AddNode(ret);
1367     }
1368     if (status->status.ok()) {
1369       // Add the node to the name-to-node mapping.
1370       desc->graph->name_map[ret->name()] = ret;
1371     } else if (ret != nullptr) {
1372       desc->graph->graph.RemoveNode(ret);
1373       ret = nullptr;
1374     }
1375   }
1376 
1377   delete desc;
1378 
1379   return ToOperation(ret);
1380 }
1381 
TF_FinishOperation(TF_OperationDescription * desc,TF_Status * status)1382 TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
1383                                  TF_Status* status) {
1384   mutex_lock l(desc->graph->mu);
1385   return TF_FinishOperationLocked(desc, status);
1386 }
1387 
1388 // TF_Operation functions
1389 // ----------------------------------------------------------
1390 
TF_OperationName(TF_Operation * oper)1391 const char* TF_OperationName(TF_Operation* oper) {
1392   return oper->node.name().c_str();
1393 }
1394 
TF_OperationOpType(TF_Operation * oper)1395 const char* TF_OperationOpType(TF_Operation* oper) {
1396   return oper->node.type_string().c_str();
1397 }
1398 
TF_OperationDevice(TF_Operation * oper)1399 const char* TF_OperationDevice(TF_Operation* oper) {
1400   return oper->node.requested_device().c_str();
1401 }
1402 
TF_OperationNumOutputs(TF_Operation * oper)1403 int TF_OperationNumOutputs(TF_Operation* oper) {
1404   return oper->node.num_outputs();
1405 }
1406 
TF_OperationOutputType(TF_Output oper_out)1407 TF_DataType TF_OperationOutputType(TF_Output oper_out) {
1408   return static_cast<TF_DataType>(
1409       oper_out.oper->node.output_type(oper_out.index));
1410 }
1411 
TF_OperationOutputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1412 int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
1413                                  TF_Status* status) {
1414   NameRangeMap name_ranges;
1415   status->status =
1416       NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
1417   if (!status->status.ok()) return -1;
1418   auto iter = name_ranges.find(arg_name);
1419   if (iter == name_ranges.end()) {
1420     status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1421     return -1;
1422   }
1423   return iter->second.second - iter->second.first;
1424 }
1425 
TF_OperationNumInputs(TF_Operation * oper)1426 int TF_OperationNumInputs(TF_Operation* oper) {
1427   return oper->node.num_inputs();
1428 }
1429 
TF_OperationInputType(TF_Input oper_in)1430 TF_DataType TF_OperationInputType(TF_Input oper_in) {
1431   return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
1432 }
1433 
TF_OperationInputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1434 int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
1435                                 TF_Status* status) {
1436   NameRangeMap name_ranges;
1437   status->status =
1438       NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
1439   if (!status->status.ok()) return -1;
1440   auto iter = name_ranges.find(arg_name);
1441   if (iter == name_ranges.end()) {
1442     status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1443     return -1;
1444   }
1445   return iter->second.second - iter->second.first;
1446 }
1447 
TF_OperationInput(TF_Input oper_in)1448 TF_Output TF_OperationInput(TF_Input oper_in) {
1449   const tensorflow::Edge* edge;
1450   Status s = oper_in.oper->node.input_edge(oper_in.index, &edge);
1451   if (!s.ok()) {
1452     return {nullptr, -1};
1453   }
1454 
1455   return {ToOperation(edge->src()), edge->src_output()};
1456 }
1457 
TF_OperationOutputNumConsumers(TF_Output oper_out)1458 int TF_OperationOutputNumConsumers(TF_Output oper_out) {
1459   int count = 0;
1460   for (const auto* edge : oper_out.oper->node.out_edges()) {
1461     if (edge->src_output() == oper_out.index) {
1462       ++count;
1463     }
1464   }
1465   return count;
1466 }
1467 
TF_OperationOutputConsumers(TF_Output oper_out,TF_Input * consumers,int max_consumers)1468 int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
1469                                 int max_consumers) {
1470   int count = 0;
1471   for (const auto* edge : oper_out.oper->node.out_edges()) {
1472     if (edge->src_output() == oper_out.index) {
1473       if (count < max_consumers) {
1474         consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
1475       }
1476       ++count;
1477     }
1478   }
1479   return count;
1480 }
1481 
TF_OperationNumControlInputs(TF_Operation * oper)1482 int TF_OperationNumControlInputs(TF_Operation* oper) {
1483   int count = 0;
1484   for (const auto* edge : oper->node.in_edges()) {
1485     if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1486       ++count;
1487     }
1488   }
1489   return count;
1490 }
1491 
TF_OperationGetControlInputs(TF_Operation * oper,TF_Operation ** control_inputs,int max_control_inputs)1492 int TF_OperationGetControlInputs(TF_Operation* oper,
1493                                  TF_Operation** control_inputs,
1494                                  int max_control_inputs) {
1495   int count = 0;
1496   for (const auto* edge : oper->node.in_edges()) {
1497     if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1498       if (count < max_control_inputs) {
1499         control_inputs[count] = ToOperation(edge->src());
1500       }
1501       ++count;
1502     }
1503   }
1504   return count;
1505 }
1506 
TF_OperationNumControlOutputs(TF_Operation * oper)1507 int TF_OperationNumControlOutputs(TF_Operation* oper) {
1508   int count = 0;
1509   for (const auto* edge : oper->node.out_edges()) {
1510     if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1511       ++count;
1512     }
1513   }
1514   return count;
1515 }
1516 
TF_OperationGetControlOutputs(TF_Operation * oper,TF_Operation ** control_outputs,int max_control_outputs)1517 int TF_OperationGetControlOutputs(TF_Operation* oper,
1518                                   TF_Operation** control_outputs,
1519                                   int max_control_outputs) {
1520   int count = 0;
1521   for (const auto* edge : oper->node.out_edges()) {
1522     if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1523       if (count < max_control_outputs) {
1524         control_outputs[count] = ToOperation(edge->dst());
1525       }
1526       ++count;
1527     }
1528   }
1529   return count;
1530 }
1531 
TF_OperationGetAttrMetadata(TF_Operation * oper,const char * attr_name,TF_Status * status)1532 TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1533                                             const char* attr_name,
1534                                             TF_Status* status) {
1535   TF_AttrMetadata metadata;
1536   const auto* attr = GetAttrValue(oper, attr_name, status);
1537   if (!status->status.ok()) return metadata;
1538   switch (attr->value_case()) {
1539 #define SINGLE_CASE(kK, attr_type, size_expr) \
1540   case tensorflow::AttrValue::kK:             \
1541     metadata.is_list = 0;                     \
1542     metadata.list_size = -1;                  \
1543     metadata.type = attr_type;                \
1544     metadata.total_size = size_expr;          \
1545     break;
1546 
1547     SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
1548     SINGLE_CASE(kI, TF_ATTR_INT, -1);
1549     SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
1550     SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
1551     SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
1552     SINGLE_CASE(kShape, TF_ATTR_SHAPE,
1553                 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
1554     SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
1555 #undef SINGLE_CASE
1556 
1557     case tensorflow::AttrValue::kList:
1558       metadata.is_list = 1;
1559       metadata.list_size = 0;
1560       metadata.total_size = -1;
1561 #define LIST_CASE(field, attr_type, ...)              \
1562   if (attr->list().field##_size() > 0) {              \
1563     metadata.type = attr_type;                        \
1564     metadata.list_size = attr->list().field##_size(); \
1565     __VA_ARGS__;                                      \
1566     break;                                            \
1567   }
1568 
1569       LIST_CASE(s, TF_ATTR_STRING, metadata.total_size = 0;
1570                 for (int i = 0; i < attr->list().s_size();
1571                      ++i) { metadata.total_size += attr->list().s(i).size(); });
1572       LIST_CASE(i, TF_ATTR_INT);
1573       LIST_CASE(f, TF_ATTR_FLOAT);
1574       LIST_CASE(b, TF_ATTR_BOOL);
1575       LIST_CASE(type, TF_ATTR_TYPE);
1576       LIST_CASE(shape, TF_ATTR_SHAPE, metadata.total_size = 0;
1577                 for (int i = 0; i < attr->list().shape_size(); ++i) {
1578                   const auto& s = attr->list().shape(i);
1579                   metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
1580                 });
1581       LIST_CASE(tensor, TF_ATTR_TENSOR);
1582       LIST_CASE(tensor, TF_ATTR_FUNC);
1583 #undef LIST_CASE
1584       // All lists empty, determine the type from the OpDef.
1585       if (metadata.list_size == 0) {
1586         for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
1587           const auto& a = oper->node.op_def().attr(i);
1588           if (a.name().compare(attr_name) != 0) continue;
1589           const string& typestr = a.type();
1590           if (typestr == "list(string)") {
1591             metadata.type = TF_ATTR_STRING;
1592           } else if (typestr == "list(int)") {
1593             metadata.type = TF_ATTR_INT;
1594           } else if (typestr == "list(float)") {
1595             metadata.type = TF_ATTR_FLOAT;
1596           } else if (typestr == "list(bool)") {
1597             metadata.type = TF_ATTR_BOOL;
1598           } else if (typestr == "list(type)") {
1599             metadata.type = TF_ATTR_TYPE;
1600           } else if (typestr == "list(shape)") {
1601             metadata.type = TF_ATTR_SHAPE;
1602           } else if (typestr == "list(tensor)") {
1603             metadata.type = TF_ATTR_TENSOR;
1604           } else if (typestr == "list(func)") {
1605             metadata.type = TF_ATTR_FUNC;
1606           } else {
1607             status->status = InvalidArgument(
1608                 "Attribute '", attr_name,
1609                 "' has an empty value of an unrecognized type '", typestr, "'");
1610             return metadata;
1611           }
1612         }
1613       }
1614       break;
1615 
1616     case tensorflow::AttrValue::kPlaceholder:
1617       metadata.is_list = 0;
1618       metadata.list_size = -1;
1619       metadata.type = TF_ATTR_PLACEHOLDER;
1620       metadata.total_size = -1;
1621       break;
1622 
1623     case tensorflow::AttrValue::kFunc:
1624       metadata.is_list = 0;
1625       metadata.list_size = -1;
1626       metadata.type = TF_ATTR_FUNC;
1627       metadata.total_size = -1;
1628       break;
1629 
1630     case tensorflow::AttrValue::VALUE_NOT_SET:
1631       status->status =
1632           InvalidArgument("Attribute '", attr_name, "' has no value set");
1633       break;
1634   }
1635   return metadata;
1636 }
1637 
TF_OperationGetAttrString(TF_Operation * oper,const char * attr_name,void * value,size_t max_length,TF_Status * status)1638 void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
1639                                void* value, size_t max_length,
1640                                TF_Status* status) {
1641   const auto* attr = GetAttrValue(oper, attr_name, status);
1642   if (!status->status.ok()) return;
1643   if (attr->value_case() != tensorflow::AttrValue::kS) {
1644     status->status =
1645         InvalidArgument("Attribute '", attr_name, "' is not a string");
1646     return;
1647   }
1648   if (max_length <= 0) {
1649     return;
1650   }
1651   const auto& s = attr->s();
1652   std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length));
1653 }
1654 
TF_OperationGetAttrStringList(TF_Operation * oper,const char * attr_name,void ** values,size_t * lengths,int max_values,void * storage,size_t storage_size,TF_Status * status)1655 void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
1656                                    void** values, size_t* lengths,
1657                                    int max_values, void* storage,
1658                                    size_t storage_size, TF_Status* status) {
1659   const auto* attr = GetAttrValue(oper, attr_name, status);
1660   if (!status->status.ok()) return;
1661   if (attr->value_case() != tensorflow::AttrValue::kList) {
1662     status->status =
1663         InvalidArgument("Value for '", attr_name, "' is not a list");
1664     return;
1665   }
1666   const auto len = std::min(max_values, attr->list().s_size());
1667   char* p = static_cast<char*>(storage);
1668   for (int i = 0; i < len; ++i) {
1669     const string& s = attr->list().s(i);
1670     values[i] = p;
1671     lengths[i] = s.size();
1672     if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
1673       status->status = InvalidArgument(
1674           "Not enough storage to hold the requested list of strings");
1675       return;
1676     }
1677     memcpy(values[i], s.data(), s.size());
1678     p += s.size();
1679   }
1680 }
1681 
1682 #define DEFINE_GETATTR(func, c_type, cpp_type, list_field)                   \
1683   void func(TF_Operation* oper, const char* attr_name, c_type* value,        \
1684             TF_Status* status) {                                             \
1685     cpp_type v;                                                              \
1686     status->status =                                                         \
1687         tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v);          \
1688     *value = static_cast<c_type>(v);                                         \
1689   }                                                                          \
1690   void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
1691                   int max_values, TF_Status* status) {                       \
1692     const auto* attr = GetAttrValue(oper, attr_name, status);                \
1693     if (!status->status.ok()) return;                                        \
1694     if (attr->value_case() != tensorflow::AttrValue::kList) {                \
1695       status->status =                                                       \
1696           InvalidArgument("Value for '", attr_name, "' is not a list.");     \
1697       return;                                                                \
1698     }                                                                        \
1699     const auto len = std::min(max_values, attr->list().list_field##_size()); \
1700     for (int i = 0; i < len; ++i) {                                          \
1701       values[i] = static_cast<c_type>(attr->list().list_field(i));           \
1702     }                                                                        \
1703   }
1704 DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
1705 DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
1706 DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
1707 DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
1708 #undef DEFINE_GETATTR
1709 
TF_OperationGetAttrShape(TF_Operation * oper,const char * attr_name,int64_t * value,int num_dims,TF_Status * status)1710 void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
1711                               int64_t* value, int num_dims, TF_Status* status) {
1712   PartialTensorShape shape;
1713   status->status =
1714       tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
1715   if (!status->status.ok()) return;
1716   auto len = std::min(shape.dims(), num_dims);
1717   for (int i = 0; i < len; ++i) {
1718     value[i] = shape.dim_size(i);
1719   }
1720 }
1721 
TF_OperationGetAttrShapeList(TF_Operation * oper,const char * attr_name,int64_t ** values,int * num_dims,int max_values,int64_t * storage,int storage_size,TF_Status * status)1722 void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
1723                                   int64_t** values, int* num_dims,
1724                                   int max_values, int64_t* storage,
1725                                   int storage_size, TF_Status* status) {
1726   std::vector<PartialTensorShape> shapes;
1727   status->status =
1728       tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
1729   if (!status->status.ok()) return;
1730   auto len = std::min(static_cast<int>(shapes.size()), max_values);
1731   int64_t* p = storage;
1732   int storage_left = storage_size;
1733   for (int i = 0; i < len; ++i) {
1734     // shapes[i].dims() == -1 for shapes with an unknown rank.
1735     int64_t n = shapes[i].dims();
1736     num_dims[i] = n;
1737     values[i] = p;
1738     if (n < 0) {
1739       continue;
1740     }
1741     if (storage_left < n) {
1742       status->status = InvalidArgument(
1743           "Not enough storage to hold the requested list of shapes");
1744       return;
1745     }
1746     storage_left -= n;
1747     for (int j = 0; j < n; ++j, ++p) {
1748       *p = shapes[i].dim_size(j);
1749     }
1750   }
1751 }
1752 
TF_OperationGetAttrTensorShapeProto(TF_Operation * oper,const char * attr_name,TF_Buffer * value,TF_Status * status)1753 void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
1754                                          const char* attr_name,
1755                                          TF_Buffer* value, TF_Status* status) {
1756   const auto* attr = GetAttrValue(oper, attr_name, status);
1757   if (!status->status.ok()) return;
1758   if (attr->value_case() != tensorflow::AttrValue::kShape) {
1759     status->status =
1760         InvalidArgument("Value for '", attr_name, "' is not a shape.");
1761     return;
1762   }
1763   status->status = MessageToBuffer(attr->shape(), value);
1764 }
1765 
TF_OperationGetAttrTensorShapeProtoList(TF_Operation * oper,const char * attr_name,TF_Buffer ** values,int max_values,TF_Status * status)1766 void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
1767                                              const char* attr_name,
1768                                              TF_Buffer** values, int max_values,
1769                                              TF_Status* status) {
1770   const auto* attr = GetAttrValue(oper, attr_name, status);
1771   if (!status->status.ok()) return;
1772   if (attr->value_case() != tensorflow::AttrValue::kList) {
1773     status->status =
1774         InvalidArgument("Value for '", attr_name, "' is not a list");
1775     return;
1776   }
1777   const auto len = std::min(max_values, attr->list().shape_size());
1778   for (int i = 0; i < len; ++i) {
1779     values[i] = TF_NewBuffer();
1780     status->status = MessageToBuffer(attr->list().shape(i), values[i]);
1781     if (!status->status.ok()) {
1782       // Delete everything allocated to far, the operation has failed.
1783       for (int j = 0; j <= i; ++j) {
1784         TF_DeleteBuffer(values[j]);
1785       }
1786       return;
1787     }
1788   }
1789 }
1790 
TF_OperationGetAttrTensor(TF_Operation * oper,const char * attr_name,TF_Tensor ** value,TF_Status * status)1791 void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
1792                                TF_Tensor** value, TF_Status* status) {
1793   *value = nullptr;
1794   Tensor t;
1795   status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
1796   if (!status->status.ok()) return;
1797   *value = TF_TensorFromTensor(t, status);
1798 }
1799 
TF_OperationGetAttrTensorList(TF_Operation * oper,const char * attr_name,TF_Tensor ** values,int max_values,TF_Status * status)1800 void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
1801                                    TF_Tensor** values, int max_values,
1802                                    TF_Status* status) {
1803   std::vector<Tensor> ts;
1804   status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
1805   if (!status->status.ok()) return;
1806   const auto len = std::min(max_values, static_cast<int>(ts.size()));
1807   for (int i = 0; i < len; ++i) {
1808     values[i] = TF_TensorFromTensor(ts[i], status);
1809   }
1810 }
1811 
TF_OperationGetAttrValueProto(TF_Operation * oper,const char * attr_name,TF_Buffer * output_attr_value,TF_Status * status)1812 void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
1813                                    TF_Buffer* output_attr_value,
1814                                    TF_Status* status) {
1815   const auto* attr = GetAttrValue(oper, attr_name, status);
1816   if (!status->status.ok()) return;
1817   status->status = MessageToBuffer(*attr, output_attr_value);
1818 }
1819 
TF_OperationToNodeDef(TF_Operation * oper,TF_Buffer * output_node_def,TF_Status * status)1820 void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
1821                            TF_Status* status) {
1822   status->status = MessageToBuffer(oper->node.def(), output_node_def);
1823 }
1824 
1825 // TF_Graph functions ---------------------------------------------------------
1826 
TF_Graph()1827 TF_Graph::TF_Graph()
1828     : graph(tensorflow::OpRegistry::Global()),
1829       refiner(graph.versions().producer(), graph.op_registry()),
1830       delete_requested(false),
1831       parent(nullptr),
1832       parent_inputs(nullptr) {}
1833 
TF_NewGraph()1834 TF_Graph* TF_NewGraph() { return new TF_Graph; }
1835 
TF_DeleteGraph(TF_Graph * g)1836 void TF_DeleteGraph(TF_Graph* g) {
1837   g->mu.lock();
1838   g->delete_requested = true;
1839   const bool del = g->sessions.empty();
1840   g->mu.unlock();
1841   if (del) delete g;
1842 }
1843 
TF_GraphOperationByName(TF_Graph * graph,const char * oper_name)1844 TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
1845   mutex_lock l(graph->mu);
1846   auto iter = graph->name_map.find(oper_name);
1847   if (iter == graph->name_map.end()) {
1848     return nullptr;
1849   } else {
1850     return ToOperation(iter->second);
1851   }
1852 }
1853 
TF_GraphNextOperation(TF_Graph * graph,size_t * pos)1854 TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
1855   if (*pos == 0) {
1856     // Advance past the first sentinel nodes in every graph (the source & sink).
1857     *pos += 2;
1858   } else {
1859     // Advance to the next node.
1860     *pos += 1;
1861   }
1862 
1863   mutex_lock l(graph->mu);
1864   while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) {
1865     Node* node = graph->graph.FindNodeId(*pos);
1866     // FindNodeId() returns nullptr for nodes that have been deleted.
1867     // We aren't currently allowing nodes to be deleted, but it is safer
1868     // to still check.
1869     if (node != nullptr) return ToOperation(node);
1870     *pos += 1;
1871   }
1872 
1873   // No more nodes.
1874   return nullptr;
1875 }
1876 
TF_GraphToGraphDef(TF_Graph * graph,TF_Buffer * output_graph_def,TF_Status * status)1877 void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
1878                         TF_Status* status) {
1879   GraphDef def;
1880   {
1881     mutex_lock l(graph->mu);
1882     graph->graph.ToGraphDef(&def);
1883   }
1884   status->status = MessageToBuffer(def, output_graph_def);
1885 }
1886 
TF_GraphGetOpDef(TF_Graph * graph,const char * op_name,TF_Buffer * output_op_def,TF_Status * status)1887 void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
1888                       TF_Buffer* output_op_def, TF_Status* status) {
1889   const OpDef* op_def;
1890   {
1891     mutex_lock l(graph->mu);
1892     status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
1893     if (!status->status.ok()) return;
1894   }
1895   status->status = MessageToBuffer(*op_def, output_op_def);
1896 }
1897 
TF_GraphVersions(TF_Graph * graph,TF_Buffer * output_version_def,TF_Status * status)1898 void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
1899                       TF_Status* status) {
1900   VersionDef versions;
1901   {
1902     mutex_lock l(graph->mu);
1903     versions = graph->graph.versions();
1904   }
1905   status->status = MessageToBuffer(versions, output_version_def);
1906 }
1907 
TF_NewImportGraphDefOptions()1908 TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
1909   return new TF_ImportGraphDefOptions;
1910 }
TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions * opts)1911 void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
1912   delete opts;
1913 }
TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions * opts,const char * prefix)1914 void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
1915                                        const char* prefix) {
1916   opts->opts.prefix = prefix;
1917 }
1918 
TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions * opts,unsigned char uniquify_names)1919 void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
1920                                               unsigned char uniquify_names) {
1921   opts->opts.uniquify_names = uniquify_names;
1922 }
1923 
TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions * opts,unsigned char uniquify_prefix)1924 void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
1925                                                unsigned char uniquify_prefix) {
1926   opts->opts.uniquify_prefix = uniquify_prefix;
1927 }
1928 
TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions * opts,const char * src_name,int src_index,TF_Output dst)1929 void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
1930                                              const char* src_name,
1931                                              int src_index, TF_Output dst) {
1932   opts->tensor_id_data.push_back(src_name);
1933   const string& src_name_str = opts->tensor_id_data.back();
1934   // We don't need to store dst's name in tensor_id_data, since `dst` must
1935   // outlive the ImportGraphDef call.
1936   opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
1937 }
1938 
TF_ImportGraphDefOptionsRemapControlDependency(TF_ImportGraphDefOptions * opts,const char * src_name,TF_Operation * dst)1939 void TF_ImportGraphDefOptionsRemapControlDependency(
1940     TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
1941   opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
1942       TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
1943 }
1944 
TF_ImportGraphDefOptionsAddControlDependency(TF_ImportGraphDefOptions * opts,TF_Operation * oper)1945 extern void TF_ImportGraphDefOptionsAddControlDependency(
1946     TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
1947   opts->opts.control_dependencies.push_back(oper->node.name());
1948 }
1949 
TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions * opts,const char * oper_name,int index)1950 void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
1951                                              const char* oper_name, int index) {
1952   opts->tensor_id_data.push_back(oper_name);
1953   const string& oper_name_str = opts->tensor_id_data.back();
1954   opts->opts.return_tensors.emplace_back(oper_name_str, index);
1955 }
1956 
TF_ImportGraphDefOptionsNumReturnOutputs(const TF_ImportGraphDefOptions * opts)1957 int TF_ImportGraphDefOptionsNumReturnOutputs(
1958     const TF_ImportGraphDefOptions* opts) {
1959   return opts->opts.return_tensors.size();
1960 }
1961 
TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions * opts,const char * oper_name)1962 void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
1963                                                 const char* oper_name) {
1964   opts->opts.return_nodes.push_back(oper_name);
1965 }
1966 
TF_ImportGraphDefOptionsNumReturnOperations(const TF_ImportGraphDefOptions * opts)1967 int TF_ImportGraphDefOptionsNumReturnOperations(
1968     const TF_ImportGraphDefOptions* opts) {
1969   return opts->opts.return_nodes.size();
1970 }
1971 
TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults * results,int * num_outputs,TF_Output ** outputs)1972 void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
1973                                            int* num_outputs,
1974                                            TF_Output** outputs) {
1975   *num_outputs = results->return_tensors.size();
1976   *outputs = results->return_tensors.data();
1977 }
1978 
TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults * results,int * num_opers,TF_Operation *** opers)1979 void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
1980                                               int* num_opers,
1981                                               TF_Operation*** opers) {
1982   *num_opers = results->return_nodes.size();
1983   *opers = results->return_nodes.data();
1984 }
1985 
TF_ImportGraphDefResultsMissingUnusedInputMappings(TF_ImportGraphDefResults * results,int * num_missing_unused_input_mappings,const char *** src_names,int ** src_indexes)1986 void TF_ImportGraphDefResultsMissingUnusedInputMappings(
1987     TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
1988     const char*** src_names, int** src_indexes) {
1989   *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
1990   *src_names = results->missing_unused_key_names.data();
1991   *src_indexes = results->missing_unused_key_indexes.data();
1992 }
1993 
TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults * results)1994 void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
1995   delete results;
1996 }
1997 
GraphImportGraphDefLocked(TF_Graph * graph,const GraphDef & def,const TF_ImportGraphDefOptions * opts,TF_ImportGraphDefResults * tf_results,TF_Status * status)1998 static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
1999                                       const TF_ImportGraphDefOptions* opts,
2000                                       TF_ImportGraphDefResults* tf_results,
2001                                       TF_Status* status)
2002     EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
2003   const int last_node_id = graph->graph.num_node_ids();
2004   tensorflow::ImportGraphDefResults results;
2005   status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
2006                                               &graph->refiner, &results);
2007   if (!status->status.ok()) return;
2008 
2009   // Add new nodes to name_map
2010   for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
2011     auto* node = graph->graph.FindNodeId(i);
2012     if (node != nullptr) graph->name_map[node->name()] = node;
2013   }
2014 
2015   // Populate return_tensors
2016   DCHECK(tf_results->return_tensors.empty());
2017   tf_results->return_tensors.resize(results.return_tensors.size());
2018   for (int i = 0; i < results.return_tensors.size(); ++i) {
2019     tf_results->return_tensors[i].oper =
2020         ToOperation(results.return_tensors[i].first);
2021     tf_results->return_tensors[i].index = results.return_tensors[i].second;
2022   }
2023 
2024   // Populate return_nodes
2025   DCHECK(tf_results->return_nodes.empty());
2026   tf_results->return_nodes.resize(results.return_nodes.size());
2027   for (int i = 0; i < results.return_nodes.size(); ++i) {
2028     tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
2029   }
2030 
2031   // Populate missing unused map keys
2032   DCHECK(tf_results->missing_unused_key_names.empty());
2033   DCHECK(tf_results->missing_unused_key_indexes.empty());
2034   DCHECK(tf_results->missing_unused_key_names_data.empty());
2035 
2036   size_t size = results.missing_unused_input_map_keys.size();
2037   tf_results->missing_unused_key_names.resize(size);
2038   tf_results->missing_unused_key_indexes.resize(size);
2039 
2040   for (int i = 0; i < size; ++i) {
2041     TensorId id = results.missing_unused_input_map_keys[i];
2042     tf_results->missing_unused_key_names_data.push_back(id.first.ToString());
2043     tf_results->missing_unused_key_names[i] =
2044         tf_results->missing_unused_key_names_data.back().c_str();
2045     tf_results->missing_unused_key_indexes[i] = id.second;
2046   }
2047 }
2048 
TF_GraphImportGraphDefWithResults(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Status * status)2049 TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
2050     TF_Graph* graph, const TF_Buffer* graph_def,
2051     const TF_ImportGraphDefOptions* options, TF_Status* status) {
2052   GraphDef def;
2053   if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
2054     status->status = InvalidArgument("Invalid GraphDef");
2055     return nullptr;
2056   }
2057   auto results = new TF_ImportGraphDefResults();
2058   mutex_lock l(graph->mu);
2059   GraphImportGraphDefLocked(graph, def, options, results, status);
2060   if (!status->status.ok()) {
2061     delete results;
2062     return nullptr;
2063   }
2064   return results;
2065 }
2066 
TF_GraphImportGraphDefWithReturnOutputs(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Output * return_outputs,int num_return_outputs,TF_Status * status)2067 void TF_GraphImportGraphDefWithReturnOutputs(
2068     TF_Graph* graph, const TF_Buffer* graph_def,
2069     const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
2070     int num_return_outputs, TF_Status* status) {
2071   if (num_return_outputs != options->opts.return_tensors.size()) {
2072     status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
2073                                      options->opts.return_tensors.size(),
2074                                      ", got ", num_return_outputs);
2075     return;
2076   }
2077   if (num_return_outputs > 0 && return_outputs == nullptr) {
2078     status->status = InvalidArgument(
2079         "'return_outputs' must be preallocated to length ", num_return_outputs);
2080     return;
2081   }
2082   GraphDef def;
2083   if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
2084     status->status = InvalidArgument("Invalid GraphDef");
2085     return;
2086   }
2087   TF_ImportGraphDefResults results;
2088   mutex_lock l(graph->mu);
2089   GraphImportGraphDefLocked(graph, def, options, &results, status);
2090   DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
2091   memcpy(return_outputs, results.return_tensors.data(),
2092          num_return_outputs * sizeof(TF_Output));
2093 }
2094 
TF_GraphImportGraphDef(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Status * status)2095 void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
2096                             const TF_ImportGraphDefOptions* options,
2097                             TF_Status* status) {
2098   TF_ImportGraphDefResults* results =
2099       TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
2100   TF_DeleteImportGraphDefResults(results);
2101 }
2102 
2103 // While loop functions -------------------------------------------------------
2104 
2105 namespace {
2106 
2107 #ifndef __ANDROID__
2108 
2109 // Creates a placeholder representing an input to the cond or body graph.
2110 // TODO(skyewm): remove these from final graph
CreateInput(const TF_Output & parent_input,TF_Graph * g,const char * name,TF_Output * input,TF_Status * status)2111 bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
2112                  TF_Output* input, TF_Status* status) {
2113   TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
2114   TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
2115   // TODO(skyewm): set placeholder shape
2116   TF_Operation* oper = TF_FinishOperation(desc, status);
2117   if (!status->status.ok()) return false;
2118   *input = {oper, 0};
2119   return true;
2120 }
2121 
2122 // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
2123 // `src_inputs[i]` will have that input replaced with `dst_inputs[i]`.  `prefix`
2124 // will be prepended to copied node names. `control_deps` are nodes in
2125 // `dst_graph` that the copied `src_graph` nodes will have control dependencies
2126 // on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
2127 // in `dst_graph` will be returned. `return_nodes` must be non-null.
CopyGraph(Graph * src_graph,Graph * dst_graph,tensorflow::ShapeRefiner * dst_refiner,const TF_Output * src_inputs,const std::vector<tensorflow::Output> & dst_inputs,const string & prefix,const std::vector<tensorflow::Operation> & control_deps,const TF_Output * nodes_to_return,int nreturn_nodes,std::vector<tensorflow::Output> * return_nodes)2128 Status CopyGraph(Graph* src_graph, Graph* dst_graph,
2129                  tensorflow::ShapeRefiner* dst_refiner,
2130                  const TF_Output* src_inputs,
2131                  const std::vector<tensorflow::Output>& dst_inputs,
2132                  const string& prefix,
2133                  const std::vector<tensorflow::Operation>& control_deps,
2134                  const TF_Output* nodes_to_return, int nreturn_nodes,
2135                  std::vector<tensorflow::Output>* return_nodes) {
2136   DCHECK(return_nodes != nullptr);
2137   GraphDef gdef;
2138   src_graph->ToGraphDef(&gdef);
2139 
2140   tensorflow::ImportGraphDefOptions opts;
2141   opts.prefix = prefix;
2142 
2143   for (int i = 0; i < dst_inputs.size(); ++i) {
2144     opts.input_map[ToTensorId(src_inputs[i])] =
2145         TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
2146   }
2147   opts.skip_mapped_nodes = true;
2148 
2149   for (const tensorflow::Operation& op : control_deps) {
2150     opts.control_dependencies.push_back(op.node()->name());
2151   }
2152 
2153   for (int i = 0; i < nreturn_nodes; ++i) {
2154     opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
2155   }
2156 
2157   // TODO(skyewm): change to OutputTensor
2158   tensorflow::ImportGraphDefResults results;
2159   TF_RETURN_IF_ERROR(
2160       ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
2161 
2162   for (const auto& pair : results.return_tensors) {
2163     return_nodes->emplace_back(pair.first, pair.second);
2164   }
2165   return Status::OK();
2166 }
2167 
ValidateConstWhileParams(const TF_WhileParams & params,TF_Status * s)2168 bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
2169   if (params.cond_graph == nullptr || params.body_graph == nullptr ||
2170       params.cond_graph->parent == nullptr ||
2171       params.cond_graph->parent != params.body_graph->parent ||
2172       params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
2173       params.ninputs <= 0 || params.cond_inputs == nullptr ||
2174       params.body_inputs == nullptr || params.body_outputs == nullptr) {
2175     s->status = InvalidArgument(
2176         "TF_WhileParams must be created by successful TF_NewWhile() call");
2177     return false;
2178   }
2179   return true;
2180 }
2181 
ValidateInputWhileParams(const TF_WhileParams & params,TF_Status * s)2182 bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
2183   if (params.cond_output.oper == nullptr) {
2184     s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
2185     return false;
2186   }
2187   for (int i = 0; i < params.ninputs; ++i) {
2188     if (params.body_outputs[i].oper == nullptr) {
2189       s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
2190                                   "field isn't set");
2191       return false;
2192     }
2193   }
2194   if (params.name == nullptr) {
2195     s->status = InvalidArgument("TF_WhileParams `name` field is null");
2196     return false;
2197   }
2198   return true;
2199 }
2200 
2201 #endif  // __ANDROID__
2202 
FreeWhileResources(const TF_WhileParams * params)2203 void FreeWhileResources(const TF_WhileParams* params) {
2204   TF_DeleteGraph(params->cond_graph);
2205   TF_DeleteGraph(params->body_graph);
2206   delete[] params->cond_inputs;
2207   delete[] params->body_inputs;
2208   delete[] params->body_outputs;
2209 }
2210 
EmptyWhileParams()2211 TF_WhileParams EmptyWhileParams() {
2212   return {0,       nullptr, nullptr, {nullptr, 0},
2213           nullptr, nullptr, nullptr, nullptr};
2214 }
2215 
2216 }  // namespace
2217 
TF_NewWhile(TF_Graph * g,TF_Output * inputs,int ninputs,TF_Status * status)2218 TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
2219                            TF_Status* status) {
2220 #ifdef __ANDROID__
2221   status->status = tensorflow::errors::Unimplemented(
2222       "Creating while loops is not supported in Android. File a bug at "
2223       "https://github.com/tensorflow/tensorflow/issues if this feature is "
2224       "important to you");
2225   return EmptyWhileParams();
2226 #else
2227   if (ninputs == 0) {
2228     status->status =
2229         InvalidArgument("TF_NewWhile() must be passed at least one input");
2230     return EmptyWhileParams();
2231   }
2232 
2233   TF_Graph* cond_graph = TF_NewGraph();
2234   TF_Graph* body_graph = TF_NewGraph();
2235   cond_graph->parent = g;
2236   cond_graph->parent_inputs = inputs;
2237   body_graph->parent = g;
2238   body_graph->parent_inputs = inputs;
2239 
2240   TF_Output* cond_inputs = new TF_Output[ninputs];
2241   TF_Output cond_output = {nullptr, -1};
2242   TF_Output* body_inputs = new TF_Output[ninputs];
2243   TF_Output* body_outputs = new TF_Output[ninputs];
2244   for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
2245   const char* name = nullptr;
2246 
2247   for (int i = 0; i < ninputs; ++i) {
2248     // TODO(skyewm): prefix names with underscore (requires some plumbing)
2249     if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
2250                      &cond_inputs[i], status)) {
2251       break;
2252     }
2253     if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
2254                      &body_inputs[i], status)) {
2255       break;
2256     }
2257   }
2258 
2259   TF_WhileParams params = {ninputs,    cond_graph,  cond_inputs,  cond_output,
2260                            body_graph, body_inputs, body_outputs, name};
2261 
2262   if (!status->status.ok()) {
2263     FreeWhileResources(&params);
2264     return EmptyWhileParams();
2265   }
2266   return params;
2267 #endif  // __ANDROID__
2268 }
2269 
2270 #ifndef __ANDROID__
2271 namespace {
2272 
2273 // TODO(skyewm): make nodes in while loop unfetchable like in Python version
TF_FinishWhileHelper(const TF_WhileParams * params,TF_Status * status,TF_Output * outputs)2274 void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
2275                           TF_Output* outputs) {
2276   if (!ValidateInputWhileParams(*params, status)) return;
2277 
2278   TF_Graph* parent = params->cond_graph->parent;
2279   TF_Output* parent_inputs = params->cond_graph->parent_inputs;
2280   int num_loop_vars = params->ninputs;
2281 
2282   mutex_lock l(parent->mu);
2283 
2284   // 'cond_fn' copies the cond graph into the parent graph.
2285   tensorflow::ops::CondGraphBuilderFn cond_fn =
2286       [params, parent](const tensorflow::Scope& scope,
2287                        const std::vector<tensorflow::Output>& inputs,
2288                        tensorflow::Output* output) {
2289         DCHECK_EQ(scope.graph(), &parent->graph);
2290         std::vector<tensorflow::Output> cond_output;
2291         TF_RETURN_IF_ERROR(CopyGraph(
2292             &params->cond_graph->graph, &parent->graph, &parent->refiner,
2293             params->cond_inputs, inputs, scope.impl()->name(),
2294             scope.impl()->control_deps(), &params->cond_output,
2295             /* nreturn_nodes */ 1, &cond_output));
2296         *output = cond_output[0];
2297         return Status::OK();
2298       };
2299 
2300   // 'body_fn' copies the body graph into the parent graph.
2301   tensorflow::ops::BodyGraphBuilderFn body_fn =
2302       [params, parent, num_loop_vars](
2303           const tensorflow::Scope& scope,
2304           const std::vector<tensorflow::Output>& inputs,
2305           std::vector<tensorflow::Output>* outputs) {
2306         DCHECK_EQ(scope.graph(), &parent->graph);
2307         TF_RETURN_IF_ERROR(
2308             CopyGraph(&params->body_graph->graph, &parent->graph,
2309                       &parent->refiner, params->body_inputs, inputs,
2310                       scope.impl()->name(), scope.impl()->control_deps(),
2311                       params->body_outputs, num_loop_vars, outputs));
2312         return Status::OK();
2313       };
2314 
2315   // Create the while loop using an internal scope.
2316   tensorflow::Scope scope =
2317       NewInternalScope(&parent->graph, &status->status, &parent->refiner)
2318           .NewSubScope(params->name);
2319 
2320   const int first_new_node_id = parent->graph.num_node_ids();
2321 
2322   tensorflow::OutputList loop_outputs;
2323   status->status = tensorflow::ops::BuildWhileLoop(
2324       scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
2325       body_fn, params->name, &loop_outputs);
2326 
2327   // Update name_map with newly-created ops.
2328   // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
2329   // a bad status. Once we fix this, we may want to return early instead of
2330   // executing the following code.
2331   for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
2332     Node* new_node = parent->graph.FindNodeId(i);
2333     if (new_node == nullptr) continue;
2334     parent->name_map[new_node->name()] = new_node;
2335   }
2336 
2337   // Populate 'outputs'.
2338   DCHECK_LE(loop_outputs.size(), num_loop_vars);
2339   for (int i = 0; i < loop_outputs.size(); ++i) {
2340     outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
2341   }
2342 }
2343 
2344 }  // namespace
2345 #endif  // __ANDROID__
2346 
TF_FinishWhile(const TF_WhileParams * params,TF_Status * status,TF_Output * outputs)2347 void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
2348                     TF_Output* outputs) {
2349 #ifdef __ANDROID__
2350   status->status = tensorflow::errors::Unimplemented(
2351       "Creating while loops is not supported in Android. File a bug at "
2352       "https://github.com/tensorflow/tensorflow/issues if this feature is "
2353       "important to you");
2354 #else
2355   // If it appears the caller created or modified `params`, don't free resources
2356   if (!ValidateConstWhileParams(*params, status)) return;
2357   TF_FinishWhileHelper(params, status, outputs);
2358   FreeWhileResources(params);
2359 #endif  // __ANDROID__
2360 }
2361 
TF_AbortWhile(const TF_WhileParams * params)2362 void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
2363 
TF_AddGradients(TF_Graph * g,TF_Output * y,int ny,TF_Output * x,int nx,TF_Output * dx,TF_Status * status,TF_Output * dy)2364 void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
2365                      TF_Output* dx, TF_Status* status, TF_Output* dy) {
2366 #ifdef __ANDROID__
2367   status->status = tensorflow::errors::Unimplemented(
2368       "Adding gradients is not supported in Android. File a bug at "
2369       "https://github.com/tensorflow/tensorflow/issues if this feature is "
2370       "important to you");
2371 #else
2372   std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
2373   std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
2374   std::vector<tensorflow::Output> dy_arg;
2375 
2376   {
2377     // We need to hold on to the lock while we have a scope that uses TF_Graph.
2378     mutex_lock graph_lock(g->mu);
2379 
2380     const int first_new_node_id = g->graph.num_node_ids();
2381 
2382     tensorflow::Scope scope =
2383         NewInternalScope(&g->graph, &status->status, &g->refiner)
2384             .NewSubScope("gradients");
2385 
2386     if (dx != nullptr) {
2387       std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
2388       status->status =
2389           AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
2390     } else {
2391       status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
2392     }
2393 
2394     // Update g->name_map with the name_map from the scope, which will contain
2395     // the new gradient ops.
2396     for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
2397       Node* n = g->graph.FindNodeId(i);
2398       if (n == nullptr) continue;
2399       g->name_map[n->name()] = n;
2400     }
2401   }
2402 
2403   // Unpack the results from grad_outputs_arg.
2404   TFOutputsFromOutputs(dy_arg, dy);
2405 #endif  // __ANDROID__
2406 }
2407 
2408 // TF_Session functions ----------------------------------------------
2409 
TF_Session(tensorflow::Session * s,TF_Graph * g)2410 TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
2411     : session(s), graph(g), last_num_graph_nodes(0), device_mgr(nullptr) {
2412   if (s->LocalDeviceManager(&device_mgr).ok()) {
2413     devices = device_mgr->ListDevices();
2414   }
2415 }
2416 
TF_NewSession(TF_Graph * graph,const TF_SessionOptions * opt,TF_Status * status)2417 TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
2418                           TF_Status* status) {
2419   Session* session;
2420   status->status = NewSession(opt->options, &session);
2421   if (status->status.ok()) {
2422     TF_Session* new_session = new TF_Session(session, graph);
2423     if (graph != nullptr) {
2424       mutex_lock l(graph->mu);
2425       graph->sessions[new_session] = Status::OK();
2426     }
2427     return new_session;
2428   } else {
2429     DCHECK_EQ(nullptr, session);
2430     return nullptr;
2431   }
2432 }
2433 
TF_LoadSessionFromSavedModel(const TF_SessionOptions * session_options,const TF_Buffer * run_options,const char * export_dir,const char * const * tags,int tags_len,TF_Graph * graph,TF_Buffer * meta_graph_def,TF_Status * status)2434 TF_Session* TF_LoadSessionFromSavedModel(
2435     const TF_SessionOptions* session_options, const TF_Buffer* run_options,
2436     const char* export_dir, const char* const* tags, int tags_len,
2437     TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
2438 // TODO(ashankar): Remove the __ANDROID__ guard. This will require ensuring that
2439 // the tensorflow/cc/saved_model:loader build target is Android friendly.
2440 #ifdef __ANDROID__
2441   status->status = tensorflow::errors::Unimplemented(
2442       "Loading a SavedModel is not supported in Android. File a bug at "
2443       "https://github.com/tensorflow/tensorflow/issues if this feature is "
2444       "important to you");
2445   return nullptr;
2446 #else
2447   mutex_lock l(graph->mu);
2448   if (!graph->name_map.empty()) {
2449     status->status = InvalidArgument("Graph is non-empty.");
2450     return nullptr;
2451   }
2452 
2453   RunOptions run_options_proto;
2454   if (run_options != nullptr && !run_options_proto.ParseFromArray(
2455                                     run_options->data, run_options->length)) {
2456     status->status = InvalidArgument("Unparseable RunOptions proto");
2457     return nullptr;
2458   }
2459 
2460   std::unordered_set<string> tag_set;
2461   for (int i = 0; i < tags_len; i++) {
2462     tag_set.insert(string(tags[i]));
2463   }
2464 
2465   tensorflow::SavedModelBundle bundle;
2466   status->status =
2467       tensorflow::LoadSavedModel(session_options->options, run_options_proto,
2468                                  export_dir, tag_set, &bundle);
2469   if (!status->status.ok()) return nullptr;
2470 
2471   // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
2472   // extends using GraphDefs. The Graph instance is different, but equivalent
2473   // to the one used to create the session.
2474   //
2475   // TODO(jhseu): When Session is modified to take Graphs instead of
2476   // GraphDefs, return the Graph generated in LoadSavedModel().
2477   TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
2478   TF_ImportGraphDefResults results;
2479   GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
2480                             import_opts, &results, status);
2481   TF_DeleteImportGraphDefOptions(import_opts);
2482   if (TF_GetCode(status) != TF_OK) return nullptr;
2483 
2484   if (meta_graph_def != nullptr) {
2485     status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
2486     if (!status->status.ok()) return nullptr;
2487   }
2488 
2489   TF_Session* session = new TF_Session(bundle.session.release(), graph);
2490 
2491   graph->sessions[session] = Status::OK();
2492   session->last_num_graph_nodes = graph->graph.num_node_ids();
2493   return session;
2494 #endif  // __ANDROID__
2495 }
2496 
TF_CloseSession(TF_Session * s,TF_Status * status)2497 void TF_CloseSession(TF_Session* s, TF_Status* status) {
2498   status->status = s->session->Close();
2499 }
2500 
TF_DeleteSession(TF_Session * s,TF_Status * status)2501 void TF_DeleteSession(TF_Session* s, TF_Status* status) {
2502   status->status = Status::OK();
2503   TF_Graph* const graph = s->graph;
2504   if (graph != nullptr) {
2505     graph->mu.lock();
2506     graph->sessions.erase(s);
2507     const bool del = graph->delete_requested && graph->sessions.empty();
2508     graph->mu.unlock();
2509     if (del) delete graph;
2510   }
2511   delete s->session;
2512   delete s;
2513 }
2514 
2515 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2516 // directly, instead of requiring us to serialize to a GraphDef and
2517 // call Session::Extend().
ExtendSessionGraphHelper(TF_Session * session,TF_Status * status)2518 static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
2519   if (session->graph != nullptr) {
2520     mutex_lock session_lock(session->mu);
2521     session->graph->mu.lock();
2522     const Graph& graph = session->graph->graph;
2523 
2524     status->status = session->graph->sessions[session];
2525     if (!status->status.ok()) {
2526       session->graph->mu.unlock();
2527       return false;
2528     }
2529 
2530     const auto num_nodes = graph.num_node_ids();
2531     if (session->last_num_graph_nodes < num_nodes) {
2532       status->status = tensorflow::ValidateNoCycles(session->graph->graph);
2533       if (!status->status.ok()) {
2534         session->graph->mu.unlock();
2535         return false;
2536       }
2537 
2538       GraphDef graph_def;
2539       *graph_def.mutable_versions() = graph.versions();
2540       // Fill graph_def with nodes with ids in the range
2541       // [session->last_num_graph_nodes, num_nodes), that is the nodes
2542       // added since the last TF_SessionRun() call.
2543       for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
2544         Node* const node = graph.FindNodeId(id);
2545         if (node != nullptr && node->IsOp()) {
2546           NodeDef* const node_def = graph_def.add_node();
2547           *node_def = node->def();
2548         }
2549       }
2550       *graph_def.mutable_library() = graph.flib_def().ToProto();
2551       session->graph->mu.unlock();
2552       status->status = session->session->Extend(graph_def);
2553       if (!status->status.ok()) {
2554         // Contract is we always delete input_values[i].
2555         return false;
2556       }
2557       // Note: session->session is not modified if Extend() fails, so
2558       // we only set last_num_graph_nodes if it succeeds.
2559       session->last_num_graph_nodes = num_nodes;
2560     } else {
2561       session->graph->mu.unlock();
2562     }
2563   }
2564   return true;
2565 }
2566 
TF_SessionRun(TF_Session * session,const TF_Buffer * run_options,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Buffer * run_metadata,TF_Status * status)2567 void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
2568                    const TF_Output* inputs, TF_Tensor* const* input_values,
2569                    int ninputs, const TF_Output* outputs,
2570                    TF_Tensor** output_values, int noutputs,
2571                    const TF_Operation* const* target_opers, int ntargets,
2572                    TF_Buffer* run_metadata, TF_Status* status) {
2573   // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2574   // directly, instead of requiring us to serialize to a GraphDef and
2575   // call Session::Extend().
2576   if (!ExtendSessionGraphHelper(session, status)) {
2577     return;
2578   }
2579 
2580   TF_Run_Setup(noutputs, output_values, status);
2581 
2582   // Convert from TF_Output and TF_Tensor to a string and Tensor.
2583   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2584   if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2585   for (int i = 0; i < ninputs; ++i) {
2586     input_pairs[i].first = OutputName(inputs[i]);
2587   }
2588 
2589   // Convert from TF_Output to string names.
2590   std::vector<string> output_names(noutputs);
2591   for (int i = 0; i < noutputs; ++i) {
2592     output_names[i] = OutputName(outputs[i]);
2593   }
2594 
2595   // Convert from TF_Operation* to string names.
2596   std::vector<string> target_names(ntargets);
2597   for (int i = 0; i < ntargets; ++i) {
2598     target_names[i] = target_opers[i]->node.name();
2599   }
2600 
2601   // Actually run.
2602   TF_Run_Helper(session->session, nullptr, run_options, input_pairs,
2603                 output_names, output_values, target_names, run_metadata,
2604                 status);
2605 }
2606 
TF_SessionPRunSetup(TF_Session * session,const TF_Output * inputs,int ninputs,const TF_Output * outputs,int noutputs,const TF_Operation * const * target_opers,int ntargets,const char ** handle,TF_Status * status)2607 void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
2608                          int ninputs, const TF_Output* outputs, int noutputs,
2609                          const TF_Operation* const* target_opers, int ntargets,
2610                          const char** handle, TF_Status* status) {
2611   *handle = nullptr;
2612 
2613   if (!ExtendSessionGraphHelper(session, status)) {
2614     return;
2615   }
2616 
2617   std::vector<string> input_names(ninputs);
2618   for (int i = 0; i < ninputs; ++i) {
2619     input_names[i] = OutputName(inputs[i]);
2620   }
2621 
2622   std::vector<string> output_names(noutputs);
2623   for (int i = 0; i < noutputs; ++i) {
2624     output_names[i] = OutputName(outputs[i]);
2625   }
2626 
2627   std::vector<string> target_names(ntargets);
2628   for (int i = 0; i < ntargets; ++i) {
2629     target_names[i] = target_opers[i]->node.name();
2630   }
2631 
2632   string new_handle;
2633   status->status = session->session->PRunSetup(input_names, output_names,
2634                                                target_names, &new_handle);
2635   if (status->status.ok()) {
2636     char* buf = new char[new_handle.size() + 1];
2637     memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
2638     *handle = buf;
2639   }
2640 }
2641 
// Frees a handle produced by TF_SessionPRunSetup(), which allocates it with
// new[].
void TF_DeletePRunHandle(const char* handle) {
  // TODO(suharshs): Free up any resources held by the partial run state.
  delete[] handle;
}
2646 
TF_SessionPRun(TF_Session * session,const char * handle,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Status * status)2647 void TF_SessionPRun(TF_Session* session, const char* handle,
2648                     const TF_Output* inputs, TF_Tensor* const* input_values,
2649                     int ninputs, const TF_Output* outputs,
2650                     TF_Tensor** output_values, int noutputs,
2651                     const TF_Operation* const* target_opers, int ntargets,
2652                     TF_Status* status) {
2653   // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2654   // directly, instead of requiring us to serialize to a GraphDef and
2655   // call Session::Extend().
2656   if (!ExtendSessionGraphHelper(session, status)) {
2657     return;
2658   }
2659 
2660   TF_Run_Setup(noutputs, output_values, status);
2661 
2662   // Convert from TF_Output and TF_Tensor to a string and Tensor.
2663   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2664   if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2665   for (int i = 0; i < ninputs; ++i) {
2666     input_pairs[i].first = OutputName(inputs[i]);
2667   }
2668 
2669   // Convert from TF_Output to string names.
2670   std::vector<string> output_names(noutputs);
2671   for (int i = 0; i < noutputs; ++i) {
2672     output_names[i] = OutputName(outputs[i]);
2673   }
2674 
2675   // Convert from TF_Operation* to string names.
2676   std::vector<string> target_names(ntargets);
2677   for (int i = 0; i < ntargets; ++i) {
2678     target_names[i] = target_opers[i]->node.name();
2679   }
2680 
2681   TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
2682                 output_values, target_names, nullptr, status);
2683 }
2684 
TF_NewApiDefMap(TF_Buffer * op_list_buffer,TF_Status * status)2685 TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
2686   tensorflow::OpList op_list;
2687   if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
2688     status->status = InvalidArgument("Unparseable OpList");
2689     return nullptr;
2690   }
2691   status->status = Status::OK();
2692   return new TF_ApiDefMap(op_list);
2693 }
2694 
TF_DeleteApiDefMap(TF_ApiDefMap * apimap)2695 void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }
2696 
TF_ApiDefMapPut(TF_ApiDefMap * api_def_map,const char * text,size_t text_len,TF_Status * status)2697 void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
2698                      size_t text_len, TF_Status* status) {
2699 #ifdef __ANDROID__
2700   status->status = tensorflow::errors::Unimplemented(
2701       "ApiDefMap is not supported in Android.");
2702 #else
2703   mutex_lock l(api_def_map->lock);
2704   if (api_def_map->update_docs_called) {
2705     status->status = FailedPrecondition(
2706         "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
2707         "called.");
2708     return;
2709   }
2710   string api_def_text(text, text_len);
2711   status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
2712 #endif  // __ANDROID__
2713 }
2714 
TF_ApiDefMapGet(TF_ApiDefMap * api_def_map,const char * name,size_t name_len,TF_Status * status)2715 TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
2716                            size_t name_len, TF_Status* status) {
2717 #ifdef __ANDROID__
2718   status->status = tensorflow::errors::Unimplemented(
2719       "ApiDefMap is not supported in Android.");
2720   return nullptr;
2721 #else
2722   mutex_lock l(api_def_map->lock);
2723   if (!api_def_map->update_docs_called) {
2724     api_def_map->api_def_map.UpdateDocs();
2725     api_def_map->update_docs_called = true;
2726   }
2727   string name_str(name, name_len);
2728   const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
2729 
2730   TF_Buffer* ret = TF_NewBuffer();
2731   status->status = MessageToBuffer(*api_def, ret);
2732   return ret;
2733 #endif  // __ANDROID__
2734 }
2735 }  // end extern "C"
2736