• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api.h"
17 
18 #include <algorithm>
19 #include <limits>
20 #include <memory>
21 #include <vector>
22 
23 // Required for IS_MOBILE_PLATFORM
24 #include "tensorflow/core/platform/platform.h"  // NOLINT
25 
26 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
27 #include "tensorflow/cc/framework/gradients.h"
28 #include "tensorflow/cc/framework/ops.h"
29 #include "tensorflow/cc/framework/scope_internal.h"
30 #include "tensorflow/cc/ops/while_loop.h"
31 #include "tensorflow/cc/saved_model/loader.h"
32 #include "tensorflow/core/distributed_runtime/server_lib.h"
33 #include "tensorflow/core/framework/op_gen_lib.h"
34 #include "tensorflow/core/kernels/logging_ops.h"
35 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
36 #include "tensorflow/c/c_api_internal.h"
37 #include "tensorflow/core/common_runtime/device_mgr.h"
38 #include "tensorflow/core/common_runtime/eval_const_tensor.h"
39 #include "tensorflow/core/common_runtime/shape_refiner.h"
40 #include "tensorflow/core/framework/allocation_description.pb.h"
41 #include "tensorflow/core/framework/kernel_def.pb.h"
42 #include "tensorflow/core/framework/log_memory.h"
43 #include "tensorflow/core/framework/node_def_util.h"
44 #include "tensorflow/core/framework/op_kernel.h"
45 #include "tensorflow/core/framework/partial_tensor_shape.h"
46 #include "tensorflow/core/framework/tensor.h"
47 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
48 #include "tensorflow/core/framework/tensor_shape.h"
49 #include "tensorflow/core/framework/tensor_shape.pb.h"
50 #include "tensorflow/core/framework/types.h"
51 #include "tensorflow/core/framework/versions.pb.h"
52 #include "tensorflow/core/graph/graph.h"
53 #include "tensorflow/core/graph/graph_constructor.h"
54 #include "tensorflow/core/graph/node_builder.h"
55 #include "tensorflow/core/graph/validate.h"
56 #include "tensorflow/core/lib/core/coding.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/lib/core/status.h"
59 #include "tensorflow/core/lib/core/stringpiece.h"
60 #include "tensorflow/core/lib/gtl/array_slice.h"
61 #include "tensorflow/core/lib/strings/str_util.h"
62 #include "tensorflow/core/lib/strings/strcat.h"
63 #include "tensorflow/core/platform/mem.h"
64 #include "tensorflow/core/platform/mutex.h"
65 #include "tensorflow/core/platform/protobuf.h"
66 #include "tensorflow/core/platform/thread_annotations.h"
67 #include "tensorflow/core/platform/types.h"
68 #include "tensorflow/core/public/session.h"
69 #include "tensorflow/core/public/version.h"
70 
71 // The implementation below is at the top level instead of the
72 // brain namespace because we are defining 'extern "C"' functions.
73 using tensorflow::AllocationDescription;
74 using tensorflow::DataType;
75 using tensorflow::ExtendSessionGraphHelper;
76 using tensorflow::Graph;
77 using tensorflow::GraphDef;
78 using tensorflow::mutex_lock;
79 using tensorflow::NameRangeMap;
80 using tensorflow::NameRangesForNode;
81 using tensorflow::NewSession;
82 using tensorflow::Node;
83 using tensorflow::NodeBuilder;
84 using tensorflow::NodeDef;
85 using tensorflow::OpDef;
86 using tensorflow::OpRegistry;
87 using tensorflow::OutputTensor;
88 using tensorflow::PartialTensorShape;
89 using tensorflow::RunMetadata;
90 using tensorflow::RunOptions;
91 using tensorflow::Session;
92 using tensorflow::Status;
93 using tensorflow::string;
94 using tensorflow::Tensor;
95 using tensorflow::TensorBuffer;
96 using tensorflow::TensorId;
97 using tensorflow::TensorShape;
98 using tensorflow::TensorShapeProto;
99 using tensorflow::VersionDef;
100 using tensorflow::error::Code;
101 using tensorflow::errors::FailedPrecondition;
102 using tensorflow::errors::InvalidArgument;
103 using tensorflow::gtl::ArraySlice;
104 using tensorflow::strings::StrCat;
105 
106 extern "C" {
107 
108 // --------------------------------------------------------------------------
TF_Version()109 const char* TF_Version() { return TF_VERSION_STRING; }
110 
111 // --------------------------------------------------------------------------
TF_DataTypeSize(TF_DataType dt)112 size_t TF_DataTypeSize(TF_DataType dt) {
113   return static_cast<size_t>(
114       tensorflow::DataTypeSize(static_cast<DataType>(dt)));
115 }
116 
117 // --------------------------------------------------------------------------
118 
TF_NewStatus()119 TF_Status* TF_NewStatus() { return new TF_Status; }
120 
TF_DeleteStatus(TF_Status * s)121 void TF_DeleteStatus(TF_Status* s) { delete s; }
122 
TF_SetStatus(TF_Status * s,TF_Code code,const char * msg)123 void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
124   if (code == TF_OK) {
125     s->status = Status::OK();
126     return;
127   }
128   s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
129 }
130 
TF_GetCode(const TF_Status * s)131 TF_Code TF_GetCode(const TF_Status* s) {
132   return static_cast<TF_Code>(s->status.code());
133 }
134 
TF_Message(const TF_Status * s)135 const char* TF_Message(const TF_Status* s) {
136   return s->status.error_message().c_str();
137 }
138 
139 // --------------------------------------------------------------------------
140 
141 namespace {
142 class TF_ManagedBuffer : public TensorBuffer {
143  public:
TF_ManagedBuffer(void * data,size_t len,void (* deallocator)(void * data,size_t len,void * arg),void * deallocator_arg)144   TF_ManagedBuffer(void* data, size_t len,
145                    void (*deallocator)(void* data, size_t len, void* arg),
146                    void* deallocator_arg)
147       : TensorBuffer(data),
148         len_(len),
149         deallocator_(deallocator),
150         deallocator_arg_(deallocator_arg) {}
151 
152   const size_t len_;
153   void (*const deallocator_)(void* data, size_t len, void* arg);
154   void* const deallocator_arg_;
155 
~TF_ManagedBuffer()156   ~TF_ManagedBuffer() override {
157     (*deallocator_)(data(), len_, deallocator_arg_);
158   }
159 
size() const160   size_t size() const override { return len_; }
root_buffer()161   TensorBuffer* root_buffer() override { return this; }
FillAllocationDescription(AllocationDescription * proto) const162   void FillAllocationDescription(AllocationDescription* proto) const override {
163     tensorflow::int64 rb = size();
164     proto->set_requested_bytes(rb);
165     proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
166   }
167 
168   // Prevents input forwarding from mutating this buffer.
OwnsMemory() const169   bool OwnsMemory() const override { return false; }
170 };
171 
allocate_tensor(const char * operation,size_t len)172 void* allocate_tensor(const char* operation, size_t len) {
173   void* data =
174       tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len);
175   if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
176     tensorflow::LogMemory::RecordRawAllocation(
177         operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID,
178         len, data, tensorflow::cpu_allocator());
179   }
180   return data;
181 }
182 
deallocate_buffer(void * data,size_t len,void * arg)183 void deallocate_buffer(void* data, size_t len, void* arg) {
184   if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
185     tensorflow::LogMemory::RecordRawDeallocation(
186         "TensorFlow C Api",
187         tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data,
188         tensorflow::cpu_allocator(), false);
189   }
190   tensorflow::cpu_allocator()->DeallocateRaw(data);
191 }
192 
193 }  // namespace
194 
~TF_Tensor()195 TF_Tensor::~TF_Tensor() { buffer->Unref(); }
196 
TF_AllocateTensor(TF_DataType dtype,const int64_t * dims,int num_dims,size_t len)197 TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
198                              int num_dims, size_t len) {
199   void* data = allocate_tensor("TF_AllocateTensor", len);
200   return TF_NewTensor(dtype, dims, num_dims, data, len, deallocate_buffer,
201                       nullptr);
202 }
203 
TF_NewTensor(TF_DataType dtype,const int64_t * dims,int num_dims,void * data,size_t len,void (* deallocator)(void * data,size_t len,void * arg),void * deallocator_arg)204 TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
205                         void* data, size_t len,
206                         void (*deallocator)(void* data, size_t len, void* arg),
207                         void* deallocator_arg) {
208   std::vector<tensorflow::int64> dimvec(num_dims);
209   for (int i = 0; i < num_dims; ++i) {
210     dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
211   }
212 
213   TF_ManagedBuffer* buf = nullptr;
214   if (dtype != TF_STRING && dtype != TF_RESOURCE &&
215       tensorflow::DataTypeCanUseMemcpy(static_cast<DataType>(dtype)) &&
216       reinterpret_cast<intptr_t>(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) !=
217           0) {
218     // TF_STRING and TF_RESOURCE tensors have a different representation in
219     // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste
220     // (any alignment requirements will be taken care of by TF_TensorToTensor
221     // and TF_TensorFromTensor).
222     //
223     // Other types have the same representation, so copy only if it is safe to
224     // do so.
225     buf = new TF_ManagedBuffer(allocate_tensor("TF_NewTensor", len), len,
226                                deallocate_buffer, nullptr);
227     std::memcpy(buf->data(), data, len);
228     // Free the original buffer.
229     deallocator(data, len, deallocator_arg);
230   } else {
231     buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
232   }
233 
234   TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf};
235   size_t elem_size = TF_DataTypeSize(dtype);
236   if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) {
237     delete ret;
238     return nullptr;
239   }
240   return ret;
241 }
242 
TF_TensorMaybeMove(TF_Tensor * tensor)243 TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
244   // It is safe to move the Tensor if and only if we own the unique reference to
245   // it. In that case, we might as well not delete and reallocate, but a future
246   // implementation might need to do so.
247   TensorBuffer* buf = tensor->buffer;
248   if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
249       buf->OwnsMemory()) {
250     return tensor;
251   }
252   return nullptr;
253 }
254 
TF_DeleteTensor(TF_Tensor * t)255 void TF_DeleteTensor(TF_Tensor* t) { delete t; }
256 
TF_TensorType(const TF_Tensor * t)257 TF_DataType TF_TensorType(const TF_Tensor* t) { return t->dtype; }
TF_NumDims(const TF_Tensor * t)258 int TF_NumDims(const TF_Tensor* t) { return t->shape.dims(); }
TF_Dim(const TF_Tensor * t,int dim_index)259 int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
260   return static_cast<int64_t>(t->shape.dim_size(dim_index));
261 }
TF_TensorByteSize(const TF_Tensor * t)262 size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); }
TF_TensorData(const TF_Tensor * t)263 void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); }
264 
TF_TensorElementCount(const TF_Tensor * t)265 int64_t TF_TensorElementCount(const TF_Tensor* t) {
266   int64_t result = 1;
267   int rank = TF_NumDims(t);
268   for (int dim = 0; dim < rank; ++dim) {
269     result *= TF_Dim(t, dim);
270   }
271   return result;
272 }
273 
// Returns the number of elements that would be present in a tensor whose
// shape is dims[0..num_dims). An empty shape (num_dims <= 0) yields 1; `dims`
// is not dereferenced in that case. No overflow or sign checking is done.
static int64_t ShapeNumElements(const int64_t* dims, int num_dims) {
  int64_t product = 1;
  const int64_t* cursor = dims;
  for (int remaining = num_dims; remaining > 0; --remaining) {
    product *= *cursor++;
  }
  return product;
}
283 
UnrefIfNonNull(::tensorflow::TensorBuffer * buf)284 static void UnrefIfNonNull(::tensorflow::TensorBuffer* buf) {
285   if (buf != nullptr) {
286     buf->Unref();
287   }
288 }
289 
RefIfNonNull(::tensorflow::TensorBuffer * buf)290 static void RefIfNonNull(::tensorflow::TensorBuffer* buf) {
291   if (buf != nullptr) {
292     buf->Ref();
293   }
294 }
295 
TF_TensorBitcastFrom(const TF_Tensor * from,TF_DataType type,TF_Tensor * to,const int64_t * new_dims,int num_new_dims,TF_Status * status)296 void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
297                           TF_Tensor* to, const int64_t* new_dims,
298                           int num_new_dims, TF_Status* status) {
299   TF_SetStatus(status, TF_OK, "");
300   size_t in_size = TF_DataTypeSize(TF_TensorType(from));
301   if (in_size == 0) {
302     TF_SetStatus(status, TF_INVALID_ARGUMENT,
303                  "input tensor has a zero-sized data type");
304     return;
305   }
306   size_t out_size = TF_DataTypeSize(type);
307   if (out_size == 0) {
308     TF_SetStatus(status, TF_INVALID_ARGUMENT,
309                  "output tensor has a zero-sized data type");
310     return;
311   }
312 
313   if (ShapeNumElements(new_dims, num_new_dims) * out_size !=
314       TF_TensorElementCount(from) * in_size) {
315     TF_SetStatus(status, TF_INVALID_ARGUMENT,
316                  "input tensor is not compatible with output shape");
317     return;
318   }
319 
320   tensorflow::TensorShapeProto p;
321   for (int i = 0; i < num_new_dims; ++i) {
322     p.add_dim()->set_size(new_dims[i]);
323   }
324   to->shape = tensorflow::TensorShape(p);
325   to->dtype = type;
326   if (to->buffer != from->buffer) {
327     UnrefIfNonNull(to->buffer);
328     to->buffer = from->buffer;
329     RefIfNonNull(to->buffer);
330   }
331 }
332 
333 // --------------------------------------------------------------------------
TF_StringEncode(const char * src,size_t src_len,char * dst,size_t dst_len,TF_Status * status)334 size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
335                        size_t dst_len, TF_Status* status) {
336   const size_t sz = TF_StringEncodedSize(src_len);
337   if (sz < src_len) {
338     status->status = InvalidArgument("src string is too large to encode");
339     return 0;
340   }
341   if (dst_len < sz) {
342     status->status =
343         InvalidArgument("dst_len (", dst_len, ") too small to encode a ",
344                         src_len, "-byte string");
345     return 0;
346   }
347   dst = tensorflow::core::EncodeVarint64(dst, src_len);
348   memcpy(dst, src, src_len);
349   return sz;
350 }
351 
TF_StringDecode_Impl(const char * src,size_t src_len,const char ** dst,size_t * dst_len)352 static Status TF_StringDecode_Impl(const char* src, size_t src_len,
353                                    const char** dst, size_t* dst_len) {
354   tensorflow::uint64 len64 = 0;
355   const char* p = tensorflow::core::GetVarint64Ptr(src, src + src_len, &len64);
356   if (p == nullptr) {
357     return InvalidArgument("invalid string encoding or truncated src buffer");
358   }
359   if (len64 > std::numeric_limits<size_t>::max()) {
360     return InvalidArgument("encoded string is ", len64,
361                            "-bytes, which is too large for this architecture");
362   }
363   *dst = p;
364   *dst_len = static_cast<size_t>(len64);
365   return Status::OK();
366 }
367 
TF_StringDecode(const char * src,size_t src_len,const char ** dst,size_t * dst_len,TF_Status * status)368 size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
369                        size_t* dst_len, TF_Status* status) {
370   status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len);
371   if (TF_GetCode(status) != TF_OK) return 0;
372   return static_cast<size_t>(*dst - src) + *dst_len;
373 }
374 
TF_StringEncodedSize(size_t len)375 size_t TF_StringEncodedSize(size_t len) {
376   return static_cast<size_t>(tensorflow::core::VarintLength(len)) + len;
377 }
378 
379 // --------------------------------------------------------------------------
TF_NewSessionOptions()380 TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
TF_DeleteSessionOptions(TF_SessionOptions * opt)381 void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
382 
TF_SetTarget(TF_SessionOptions * options,const char * target)383 void TF_SetTarget(TF_SessionOptions* options, const char* target) {
384   options->options.target = target;
385 }
386 
TF_SetConfig(TF_SessionOptions * options,const void * proto,size_t proto_len,TF_Status * status)387 void TF_SetConfig(TF_SessionOptions* options, const void* proto,
388                   size_t proto_len, TF_Status* status) {
389   if (!options->options.config.ParseFromArray(proto, proto_len)) {
390     status->status = InvalidArgument("Unparseable ConfigProto");
391   }
392 }
393 // --------------------------------------------------------------------------
TF_NewBuffer()394 TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; }
395 
TF_NewBufferFromString(const void * proto,size_t proto_len)396 TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
397   void* copy = tensorflow::port::Malloc(proto_len);
398   memcpy(copy, proto, proto_len);
399 
400   TF_Buffer* buf = new TF_Buffer;
401   buf->data = copy;
402   buf->length = proto_len;
403   buf->data_deallocator = [](void* data, size_t length) {
404     tensorflow::port::Free(data);
405   };
406   return buf;
407 }
408 
TF_DeleteBuffer(TF_Buffer * buffer)409 void TF_DeleteBuffer(TF_Buffer* buffer) {
410   if (buffer == nullptr) return;
411   if (buffer->data_deallocator != nullptr) {
412     (*buffer->data_deallocator)(const_cast<void*>(buffer->data),
413                                 buffer->length);
414   }
415   delete buffer;
416 }
417 
TF_GetBuffer(TF_Buffer * buffer)418 TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }
419 
420 // --------------------------------------------------------------------------
421 
TF_NewDeprecatedSession(const TF_SessionOptions * opt,TF_Status * status)422 TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
423                                               TF_Status* status) {
424   Session* session;
425   status->status = NewSession(opt->options, &session);
426   if (TF_GetCode(status) == TF_OK) {
427     return new TF_DeprecatedSession({session});
428   } else {
429     DCHECK_EQ(nullptr, session);
430     return nullptr;
431   }
432 }
433 
TF_CloseDeprecatedSession(TF_DeprecatedSession * s,TF_Status * status)434 void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
435   status->status = s->session->Close();
436 }
437 
TF_DeleteDeprecatedSession(TF_DeprecatedSession * s,TF_Status * status)438 void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
439   status->status = Status::OK();
440   if (s == nullptr) return;
441   delete s->session;
442   delete s;
443 }
444 
TF_ExtendGraph(TF_DeprecatedSession * s,const void * proto,size_t proto_len,TF_Status * status)445 void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
446                     size_t proto_len, TF_Status* status) {
447   GraphDef g;
448   if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
449     status->status = InvalidArgument("Invalid GraphDef");
450     return;
451   }
452   status->status = s->session->Extend(g);
453 }
454 
DeleteArray(void * data,size_t size,void * arg)455 static void DeleteArray(void* data, size_t size, void* arg) {
456   DCHECK_EQ(data, arg);
457   delete[] reinterpret_cast<char*>(arg);
458 }
459 
460 }  // end extern "C"
461 
462 namespace tensorflow {
463 namespace {
464 
465 // Reset helper for converting character arrays to string vectors.
TF_Reset_Helper(const TF_SessionOptions * opt,const char ** containers,int ncontainers,TF_Status * status)466 void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers,
467                      int ncontainers, TF_Status* status) {
468   std::vector<string> container_names(ncontainers);
469   for (int i = 0; i < ncontainers; ++i) {
470     container_names[i] = containers[i];
471   }
472 
473   status->status = Reset(opt->options, container_names);
474 }
475 
476 }  // namespace
477 }  // namespace tensorflow
478 
479 extern "C" {
480 
TF_Reset(const TF_SessionOptions * opt,const char ** containers,int ncontainers,TF_Status * status)481 void TF_Reset(const TF_SessionOptions* opt, const char** containers,
482               int ncontainers, TF_Status* status) {
483   tensorflow::TF_Reset_Helper(opt, containers, ncontainers, status);
484 }
485 
486 }  // end extern "C"
487 
488 namespace tensorflow {
489 
TF_TensorToTensor(const TF_Tensor * src,Tensor * dst)490 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
491   if (src->dtype == TF_RESOURCE) {
492     if (src->shape.dims() != 0) {
493       return InvalidArgument(
494           "Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
495           "shape ",
496           src->shape.DebugString());
497     }
498     *dst = Tensor(DT_RESOURCE, src->shape);
499     if (!dst->scalar<ResourceHandle>()().ParseFromString(
500             string(static_cast<const char*>(TF_TensorData(src)),
501                    TF_TensorByteSize(src)))) {
502       return InvalidArgument(
503           "Malformed TF_RESOUCE tensor: unable to parse resource handle");
504     }
505     return Status::OK();
506   }
507   if (src->dtype != TF_STRING) {
508     *dst = TensorCApi::MakeTensor(src->dtype, src->shape, src->buffer);
509     return Status::OK();
510   }
511   // TF_STRING tensors require copying since Tensor class expects a sequence of
512   // string objects.
513   const tensorflow::int64 num_elements = src->shape.num_elements();
514   const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
515   const size_t src_size = TF_TensorByteSize(src);
516   if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
517       num_elements) {
518     return InvalidArgument(
519         "Malformed TF_STRING tensor; too short to hold number of elements");
520   }
521   const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
522   const char* limit = input + src_size;
523 
524   *dst = Tensor(static_cast<DataType>(src->dtype), src->shape);
525   auto dstarray = dst->flat<string>();
526   for (tensorflow::int64 i = 0; i < num_elements; ++i) {
527     tensorflow::uint64 offset =
528         reinterpret_cast<const tensorflow::uint64*>(input)[i];
529     if (static_cast<ptrdiff_t>(offset) >= (limit - data_start)) {
530       return InvalidArgument("Malformed TF_STRING tensor; element ", i,
531                              " out of range");
532     }
533     size_t len;
534     const char* p;
535     const char* srcp = data_start + offset;
536     Status status = TF_StringDecode_Impl(srcp, limit - srcp, &p, &len);
537     if (!status.ok()) return status;
538     dstarray(i).assign(p, len);
539   }
540   return Status::OK();
541 }
542 
543 // Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
544 // result in a zero-sized tensor.
EmptyTensor(TF_DataType dtype,const TensorShape & shape)545 static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) {
546   static char empty;
547   tensorflow::int64 nelems = 1;
548   std::vector<tensorflow::int64> dims;
549   for (int i = 0; i < shape.dims(); ++i) {
550     dims.push_back(shape.dim_size(i));
551     nelems *= shape.dim_size(i);
552   }
553   CHECK_EQ(nelems, 0);
554   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
555                 "64-bit int types should match in size");
556   return TF_NewTensor(
557       dtype, reinterpret_cast<const int64_t*>(dims.data()), shape.dims(),
558       reinterpret_cast<void*>(&empty), 0, [](void*, size_t, void*) {}, nullptr);
559 }
560 
561 // Non-static for testing.
TF_TensorFromTensor(const tensorflow::Tensor & src,TF_Status * status)562 TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
563                                TF_Status* status) {
564   TF_SetStatus(status, TF_OK, "");
565   if (!src.IsInitialized()) {
566     status->status = FailedPrecondition(
567         "attempt to use a tensor with an uninitialized value");
568     return nullptr;
569   }
570   if (src.NumElements() == 0) {
571     return EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
572   }
573   if (src.dtype() == DT_RESOURCE) {
574     if (src.shape().dims() != 0) {
575       status->status = InvalidArgument(
576           "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
577           src.shape().DebugString(),
578           "). Please file a bug at "
579           "https://github.com/tensorflow/tensorflow/issues/new, "
580           "ideally with a "
581           "short code snippet that reproduces this error.");
582       return nullptr;
583     }
584     const string str = src.scalar<ResourceHandle>()().SerializeAsString();
585     TF_Tensor* t = TF_AllocateTensor(TF_RESOURCE, {}, 0, str.size());
586     std::memcpy(TF_TensorData(t), str.c_str(), str.size());
587     return t;
588   }
589   if (src.dtype() != DT_STRING) {
590     TensorBuffer* buf = TensorCApi::Buffer(src);
591     buf->Ref();
592     return new TF_Tensor{static_cast<TF_DataType>(src.dtype()), src.shape(),
593                          buf};
594   }
595   // DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
596   // encoded sequence of strings.
597 
598   // Compute bytes needed for encoding.
599   size_t size = 0;
600   const auto& srcarray = src.flat<string>();
601   for (int i = 0; i < srcarray.size(); ++i) {
602     const string& s = srcarray(i);
603     // uint64 starting_offset, TF_StringEncode-d string.
604     size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size());
605   }
606 
607   // Encode all strings.
608   char* base = new char[size];
609   char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size();
610   char* dst = data_start;  // Where next string is encoded.
611   size_t dst_len = size - static_cast<size_t>(data_start - base);
612   tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
613   for (int i = 0; i < srcarray.size(); ++i) {
614     *offsets = (dst - data_start);
615     offsets++;
616     const string& s = srcarray(i);
617     size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
618     if (TF_GetCode(status) != TF_OK) {
619       status->status = InvalidArgument(
620           "invalid string tensor encoding (string #", i, " of ",
621           srcarray.size(), "): ", status->status.error_message());
622       delete[] base;
623       return nullptr;
624     }
625     dst += consumed;
626     dst_len -= consumed;
627   }
628   if (dst != base + size) {
629     status->status = InvalidArgument(
630         "invalid string tensor encoding (decoded ", (dst - base),
631         " bytes, but the tensor is encoded in ", size, " bytes");
632     delete[] base;
633     return nullptr;
634   }
635 
636   auto dims = src.shape().dim_sizes();
637   std::vector<tensorflow::int64> dimvec(dims.size());
638   for (size_t i = 0; i < dims.size(); ++i) {
639     dimvec[i] = dims[i];
640   }
641   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
642                 "64-bit int types should match in size");
643   return TF_NewTensor(TF_STRING,
644                       reinterpret_cast<const int64_t*>(dimvec.data()),
645                       dimvec.size(), base, size, DeleteArray, base);
646 }
647 
MessageToBuffer(const tensorflow::protobuf::MessageLite & in,TF_Buffer * out)648 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
649                        TF_Buffer* out) {
650   if (out->data != nullptr) {
651     return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
652   }
653   const size_t proto_size = in.ByteSizeLong();
654   void* buf = tensorflow::port::Malloc(proto_size);
655   if (buf == nullptr) {
656     return tensorflow::errors::ResourceExhausted(
657         "Failed to allocate memory to serialize message of type '",
658         in.GetTypeName(), "' and size ", proto_size);
659   }
660   // SerializeToArray takes size as an int.
661   // This next 'if' is a workaround till we update to depend on a version
662   // of protocol buffers that includes
663   // https://github.com/google/protobuf/pull/4739
664   if (proto_size > std::numeric_limits<int>::max()) {
665     return InvalidArgument("Cannot serialize protocol buffer of type ",
666                            in.GetTypeName(), " as the serialized size (",
667                            proto_size,
668                            "bytes) would be larger than the limit (",
669                            std::numeric_limits<int>::max(), " bytes)");
670   }
671   if (!in.SerializeToArray(buf, proto_size)) {
672     return InvalidArgument("Unable to serialize ", in.GetTypeName(),
673                            " protocol buffer, perhaps the serialized size (",
674                            proto_size, " bytes) is too large?");
675   }
676   out->data = buf;
677   out->length = proto_size;
678   out->data_deallocator = [](void* data, size_t length) {
679     tensorflow::port::Free(data);
680   };
681   return Status::OK();
682 }
683 
RecordMutation(TF_Graph * graph,const TF_Operation & op,const char * mutation_type)684 void RecordMutation(TF_Graph* graph, const TF_Operation& op,
685                     const char* mutation_type) {
686   // If any session has already run this node_id, mark this session as
687   // unrunnable.
688   for (auto it : graph->sessions) {
689     mutex_lock session_lock(it.first->mu);
690     if (it.first->last_num_graph_nodes > op.node.id()) {
691       it.second = strings::StrCat(
692           "Operation '", op.node.DebugString(), "' was changed by ",
693           mutation_type,
694           " after it was run by a session. This mutation will have no effect, "
695           "and will trigger an error in the future. Either don't modify "
696           "nodes after running them or create a new session.");
697     }
698   }
699 }
700 
701 namespace {
702 
703 // Helper method that creates a shape handle for a shape described by dims.
ShapeHandleFromDims(tensorflow::shape_inference::InferenceContext * ic,int num_dims,const int64_t * dims)704 tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
705     tensorflow::shape_inference::InferenceContext* ic, int num_dims,
706     const int64_t* dims) {
707   if (num_dims != -1) {
708     std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
709     dim_vec.reserve(num_dims);
710     for (int i = 0; i < num_dims; ++i) {
711       dim_vec.push_back(ic->MakeDim(dims[i]));
712     }
713     return ic->MakeShape(dim_vec);
714   } else {
715     return ic->UnknownShape();
716   }
717 }
718 
719 }  // namespace
720 
TF_GraphSetOutputHandleShapesAndTypes(TF_Graph * graph,TF_Output output,int num_shapes_and_types,const int64_t ** shapes,const int * ranks,const TF_DataType * types,TF_Status * status)721 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
722                                            int num_shapes_and_types,
723                                            const int64_t** shapes,
724                                            const int* ranks,
725                                            const TF_DataType* types,
726                                            TF_Status* status) {
727   Node* node = &output.oper->node;
728 
729   mutex_lock l(graph->mu);
730   tensorflow::shape_inference::InferenceContext* ic =
731       graph->refiner.GetContext(node);
732   if (ic == nullptr) {
733     status->status =
734         InvalidArgument("Node ", node->name(), " was not found in the graph");
735     return;
736   }
737 
738   auto shape_and_type_vec =
739       std::vector<tensorflow::shape_inference::ShapeAndType>(
740           num_shapes_and_types);
741   for (int i = 0; i < num_shapes_and_types; ++i) {
742     tensorflow::shape_inference::ShapeHandle shape_handle =
743         ShapeHandleFromDims(ic, ranks[i], shapes[i]);
744     shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
745         shape_handle, static_cast<DataType>(types[i]));
746   }
747 
748   ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
749 }
750 
// Helpers for loading a TensorFlow plugin (a .so file).
//
// Declared here and defined elsewhere. As used by TF_LoadLibrary in this file:
//   - `result` receives the opaque library handle (stored in
//     TF_Library::lib_handle);
//   - `buf`/`len` receive a buffer and its length (stored in
//     TF_Library::op_list);
//   - TF_DeleteLibraryHandle later frees that buffer with
//     tensorflow::port::Free.
// NOTE(review): presumably `buf` holds a serialized OpList of the ops the
// plugin registers, judging by the op_list field name. Confirm against the
// definition.
Status LoadLibrary(const char* library_filename, void** result,
                   const void** buf, size_t* len);
754 
755 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
756 // directly, instead of requiring us to serialize to a GraphDef and
757 // call Session::Extend().
ExtendSessionGraphHelper(TF_Session * session,TF_Status * status)758 bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
759   if (session->graph != nullptr) {
760     // Take the graph lock before the session lock to avoid deadlock. This is
761     // safe since session->graph does not change.
762     session->graph->mu.lock();
763     mutex_lock session_lock(session->mu);
764     const Graph& graph = session->graph->graph;
765 
766     const string& mutation_warning = session->graph->sessions[session];
767     if (!mutation_warning.empty()) {
768       // TODO(b/74949947): turn this back into an error status
769       LOG(WARNING) << mutation_warning;
770       session->graph->sessions[session].clear();
771     }
772 
773     const auto num_nodes = graph.num_node_ids();
774     if (session->last_num_graph_nodes < num_nodes) {
775       // TODO(nolivia): check this on a subset of the graph instead of all of
776       // it.
777       status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
778       if (TF_GetCode(status) != TF_OK) {
779         session->graph->mu.unlock();
780         return false;
781       }
782 
783       GraphDef graph_def;
784       *graph_def.mutable_versions() = graph.versions();
785       // Fill graph_def with nodes with ids in the range
786       // [session->last_num_graph_nodes, num_nodes), that is the nodes
787       // added since the last TF_SessionRun() call.
788       for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
789         Node* const node = graph.FindNodeId(id);
790         if (node != nullptr && node->IsOp()) {
791           NodeDef* const node_def = graph_def.add_node();
792           *node_def = node->def();
793         }
794       }
795       *graph_def.mutable_library() = graph.flib_def().ToProto();
796       session->graph->mu.unlock();
797       status->status = session->session->Extend(graph_def);
798       if (TF_GetCode(status) != TF_OK) {
799         // Contract is we always delete input_values[i].
800         return false;
801       }
802       // Note: session->session is not modified if Extend() fails, so
803       // we only set last_num_graph_nodes if it succeeds.
804       session->last_num_graph_nodes = num_nodes;
805     } else {
806       session->graph->mu.unlock();
807     }
808   }
809   return true;
810 }
811 
812 }  // namespace tensorflow
813 
TF_Run_Setup(int noutputs,TF_Tensor ** c_outputs,TF_Status * status)814 static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
815                          TF_Status* status) {
816   status->status = Status::OK();
817   for (int i = 0; i < noutputs; ++i) {
818     c_outputs[i] = nullptr;
819   }
820 }
821 
TF_Run_Inputs(TF_Tensor * const * c_inputs,std::vector<std::pair<string,Tensor>> * input_pairs,TF_Status * status)822 static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
823                           std::vector<std::pair<string, Tensor>>* input_pairs,
824                           TF_Status* status) {
825   const int ninputs = input_pairs->size();
826   for (int i = 0; i < ninputs; ++i) {
827     status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
828     if (TF_GetCode(status) != TF_OK) return false;
829   }
830   return true;
831 }
832 
TF_Run_Helper(Session * session,const char * handle,const TF_Buffer * run_options,const std::vector<std::pair<string,Tensor>> & input_pairs,const std::vector<string> & output_tensor_names,TF_Tensor ** c_outputs,const std::vector<string> & target_oper_names,TF_Buffer * run_metadata,TF_Status * status)833 static void TF_Run_Helper(
834     Session* session, const char* handle, const TF_Buffer* run_options,
835     // Input tensors
836     const std::vector<std::pair<string, Tensor>>& input_pairs,
837     // Output tensors
838     const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
839     // Target nodes
840     const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
841     TF_Status* status) {
842   const int noutputs = output_tensor_names.size();
843   std::vector<Tensor> outputs(noutputs);
844   Status result;
845 
846   if (handle == nullptr) {
847     RunOptions run_options_proto;
848     if (run_options != nullptr && !run_options_proto.ParseFromArray(
849                                       run_options->data, run_options->length)) {
850       status->status = InvalidArgument("Unparseable RunOptions proto");
851       return;
852     }
853     if (run_metadata != nullptr && run_metadata->data != nullptr) {
854       status->status =
855           InvalidArgument("Passing non-empty run_metadata is invalid.");
856       return;
857     }
858 
859     RunMetadata run_metadata_proto;
860     result = session->Run(run_options_proto, input_pairs, output_tensor_names,
861                           target_oper_names, &outputs, &run_metadata_proto);
862 
863     // Serialize back to upstream client, who now owns the new buffer
864     if (run_metadata != nullptr) {
865       status->status = MessageToBuffer(run_metadata_proto, run_metadata);
866       if (TF_GetCode(status) != TF_OK) return;
867     }
868   } else {
869     // NOTE(zongheng): PRun does not support RunOptions yet.
870     result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
871   }
872   if (!result.ok()) {
873     status->status = result;
874     return;
875   }
876 
877   // Store results in c_outputs[]
878   for (int i = 0; i < noutputs; ++i) {
879     const Tensor& src = outputs[i];
880     if (!src.IsInitialized() || src.NumElements() == 0) {
881       c_outputs[i] =
882           EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
883       continue;
884     }
885     c_outputs[i] = TF_TensorFromTensor(src, status);
886     if (TF_GetCode(status) != TF_OK) return;
887   }
888 }
889 
890 extern "C" {
891 
TF_Run(TF_DeprecatedSession * s,const TF_Buffer * run_options,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Buffer * run_metadata,TF_Status * status)892 void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
893             // Input tensors
894             const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
895             // Output tensors
896             const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
897             // Target nodes
898             const char** c_target_oper_names, int ntargets,
899             TF_Buffer* run_metadata, TF_Status* status) {
900   TF_Run_Setup(noutputs, c_outputs, status);
901   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
902   if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
903   for (int i = 0; i < ninputs; ++i) {
904     input_pairs[i].first = c_input_names[i];
905   }
906   std::vector<string> output_names(noutputs);
907   for (int i = 0; i < noutputs; ++i) {
908     output_names[i] = c_output_names[i];
909   }
910   std::vector<string> target_oper_names(ntargets);
911   for (int i = 0; i < ntargets; ++i) {
912     target_oper_names[i] = c_target_oper_names[i];
913   }
914   TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
915                 c_outputs, target_oper_names, run_metadata, status);
916 }
917 
TF_PRunSetup(TF_DeprecatedSession * s,const char ** c_input_names,int ninputs,const char ** c_output_names,int noutputs,const char ** c_target_oper_names,int ntargets,const char ** handle,TF_Status * status)918 void TF_PRunSetup(TF_DeprecatedSession* s,
919                   // Input names
920                   const char** c_input_names, int ninputs,
921                   // Output names
922                   const char** c_output_names, int noutputs,
923                   // Target nodes
924                   const char** c_target_oper_names, int ntargets,
925                   const char** handle, TF_Status* status) {
926   *handle = nullptr;
927 
928   std::vector<string> input_names(ninputs);
929   std::vector<string> output_names(noutputs);
930   std::vector<string> target_oper_names(ntargets);
931   for (int i = 0; i < ninputs; ++i) {
932     input_names[i] = c_input_names[i];
933   }
934   for (int i = 0; i < noutputs; ++i) {
935     output_names[i] = c_output_names[i];
936   }
937   for (int i = 0; i < ntargets; ++i) {
938     target_oper_names[i] = c_target_oper_names[i];
939   }
940   string new_handle;
941   status->status = s->session->PRunSetup(input_names, output_names,
942                                          target_oper_names, &new_handle);
943   if (TF_GetCode(status) == TF_OK) {
944     char* buf = new char[new_handle.size() + 1];
945     memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
946     *handle = buf;
947   }
948 }
949 
TF_PRun(TF_DeprecatedSession * s,const char * handle,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Status * status)950 void TF_PRun(TF_DeprecatedSession* s, const char* handle,
951              // Input tensors
952              const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
953              // Output tensors
954              const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
955              // Target nodes
956              const char** c_target_oper_names, int ntargets,
957              TF_Status* status) {
958   TF_Run_Setup(noutputs, c_outputs, status);
959   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
960   if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
961   for (int i = 0; i < ninputs; ++i) {
962     input_pairs[i].first = c_input_names[i];
963   }
964 
965   std::vector<string> output_names(noutputs);
966   for (int i = 0; i < noutputs; ++i) {
967     output_names[i] = c_output_names[i];
968   }
969   std::vector<string> target_oper_names(ntargets);
970   for (int i = 0; i < ntargets; ++i) {
971     target_oper_names[i] = c_target_oper_names[i];
972   }
973   TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
974                 c_outputs, target_oper_names, nullptr, status);
975 }
976 
TF_LoadLibrary(const char * library_filename,TF_Status * status)977 TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
978   TF_Library* lib_handle = new TF_Library;
979   status->status = tensorflow::LoadLibrary(
980       library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
981       &lib_handle->op_list.length);
982   if (TF_GetCode(status) != TF_OK) {
983     delete lib_handle;
984     return nullptr;
985   }
986   return lib_handle;
987 }
988 
TF_GetOpList(TF_Library * lib_handle)989 TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
990 
TF_DeleteLibraryHandle(TF_Library * lib_handle)991 void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
992   if (lib_handle == nullptr) return;
993   tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
994   delete lib_handle;
995 }
996 
TF_GetAllOpList()997 TF_Buffer* TF_GetAllOpList() {
998   std::vector<tensorflow::OpDef> op_defs;
999   tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
1000   tensorflow::OpList op_list;
1001   for (const auto& op : op_defs) {
1002     *(op_list.add_op()) = op;
1003   }
1004   TF_Buffer* ret = TF_NewBuffer();
1005   TF_CHECK_OK(MessageToBuffer(op_list, ret));
1006   return ret;
1007 }
1008 
1009 // --------------------------------------------------------------------------
1010 // ListDevices & SessionListDevices API
1011 
TF_DeleteDeviceList(TF_DeviceList * s)1012 void TF_DeleteDeviceList(TF_DeviceList* s) { delete s; }
1013 
TF_SessionListDevices(TF_Session * session,TF_Status * status)1014 TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
1015   TF_DeviceList* response = new TF_DeviceList;
1016   status->status = session->session->ListDevices(&response->response);
1017   return response;
1018 }
1019 
TF_DeprecatedSessionListDevices(TF_DeprecatedSession * session,TF_Status * status)1020 TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
1021                                                TF_Status* status) {
1022   TF_DeviceList* response = new TF_DeviceList;
1023   status->status = session->session->ListDevices(&response->response);
1024   return response;
1025 }
1026 
TF_DeviceListCount(const TF_DeviceList * list)1027 int TF_DeviceListCount(const TF_DeviceList* list) {
1028   return list->response.size();
1029 }
1030 
1031 #define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
1032   return_type method_name(const TF_DeviceList* list, const int index,     \
1033                           TF_Status* status) {                            \
1034     if (list == nullptr) {                                                \
1035       status->status = InvalidArgument("list is null!");                  \
1036       return err_val;                                                     \
1037     }                                                                     \
1038     if (index < 0 || index >= list->response.size()) {                    \
1039       status->status = InvalidArgument("index out of bounds");            \
1040       return err_val;                                                     \
1041     }                                                                     \
1042     status->status = Status::OK();                                        \
1043     return list->response[index].accessor;                                \
1044   }
1045 
1046 TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
1047 TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
1048                      nullptr);
1049 TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
1050 TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);
1051 
1052 #undef TF_DEVICELIST_METHOD
1053 
1054 }  // end extern "C"
1055 
1056 // --------------------------------------------------------------------------
1057 // New Graph and Session API
1058 
1059 // Helper functions -----------------------------------------------------------
1060 
1061 namespace {
1062 
ToOperation(Node * node)1063 TF_Operation* ToOperation(Node* node) {
1064   return static_cast<TF_Operation*>(static_cast<void*>(node));
1065 }
1066 
OutputName(const TF_Output & output)1067 string OutputName(const TF_Output& output) {
1068   return StrCat(output.oper->node.name(), ":", output.index);
1069 }
1070 
GetAttrValue(TF_Operation * oper,const char * attr_name,TF_Status * status)1071 const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
1072                                           const char* attr_name,
1073                                           TF_Status* status) {
1074   const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
1075   if (attr == nullptr) {
1076     status->status = InvalidArgument("Operation '", oper->node.name(),
1077                                      "' has no attr named '", attr_name, "'.");
1078   }
1079   return attr;
1080 }
1081 
ToTensorId(const TF_Output & output)1082 TensorId ToTensorId(const TF_Output& output) {
1083   return TensorId(output.oper->node.name(), output.index);
1084 }
1085 
1086 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
OutputsFromTFOutputs(TF_Output * tf_outputs,int n)1087 std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
1088                                                      int n) {
1089   std::vector<tensorflow::Output> outputs(n);
1090   for (int i = 0; i < n; ++i) {
1091     outputs[i] =
1092         tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
1093   }
1094   return outputs;
1095 }
1096 
TFOutputsFromOutputs(const std::vector<tensorflow::Output> & outputs,TF_Output * tf_outputs)1097 void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
1098                           TF_Output* tf_outputs) {
1099   for (int i = 0; i < outputs.size(); i++) {
1100     tf_outputs[i].oper = ToOperation(outputs[i].node());
1101     tf_outputs[i].index = outputs[i].index();
1102   }
1103 }
1104 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1105 
1106 }  // namespace
1107 
1108 // Shape functions -----------------------------------------------------------
1109 
TF_GraphSetTensorShape(TF_Graph * graph,TF_Output output,const int64_t * dims,const int num_dims,TF_Status * status)1110 void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
1111                             const int64_t* dims, const int num_dims,
1112                             TF_Status* status) {
1113   Node* node = &output.oper->node;
1114 
1115   mutex_lock l(graph->mu);
1116   tensorflow::shape_inference::InferenceContext* ic =
1117       graph->refiner.GetContext(node);
1118   if (ic == nullptr) {
1119     status->status =
1120         InvalidArgument("Node ", node->name(), " was not found in the graph");
1121     return;
1122   }
1123   tensorflow::shape_inference::ShapeHandle new_shape =
1124       tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
1125   status->status = graph->refiner.SetShape(node, output.index, new_shape);
1126 }
1127 
TF_GraphGetTensorNumDims(TF_Graph * graph,TF_Output output,TF_Status * status)1128 int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
1129                              TF_Status* status) {
1130   Node* node = &output.oper->node;
1131 
1132   mutex_lock l(graph->mu);
1133   tensorflow::shape_inference::InferenceContext* ic =
1134       graph->refiner.GetContext(node);
1135   if (ic == nullptr) {
1136     status->status =
1137         InvalidArgument("Node ", node->name(), " was not found in the graph");
1138     return -1;
1139   }
1140 
1141   tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
1142 
1143   // Unknown rank means the number of dimensions is -1.
1144   if (!ic->RankKnown(shape)) {
1145     return -1;
1146   }
1147 
1148   return ic->Rank(shape);
1149 }
1150 
TF_GraphGetTensorShape(TF_Graph * graph,TF_Output output,int64_t * dims,int num_dims,TF_Status * status)1151 void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
1152                             int num_dims, TF_Status* status) {
1153   Node* node = &output.oper->node;
1154 
1155   mutex_lock l(graph->mu);
1156   tensorflow::shape_inference::InferenceContext* ic =
1157       graph->refiner.GetContext(node);
1158   if (ic == nullptr) {
1159     status->status =
1160         InvalidArgument("Node ", node->name(), " was not found in the graph");
1161     return;
1162   }
1163 
1164   tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
1165 
1166   int rank = -1;
1167   if (ic->RankKnown(shape)) {
1168     rank = ic->Rank(shape);
1169   }
1170 
1171   if (num_dims != rank) {
1172     status->status = InvalidArgument("Expected rank is ", num_dims,
1173                                      " but actual rank is ", rank);
1174     return;
1175   }
1176 
1177   if (num_dims == 0) {
1178     // Output shape is a scalar.
1179     return;
1180   }
1181 
1182   // Rank is greater than 0, so fill in the values, if known, and
1183   // -1 for unknown values.
1184   for (int i = 0; i < num_dims; ++i) {
1185     auto dim = ic->Dim(shape, i);
1186     tensorflow::int64 value = -1;
1187     if (ic->ValueKnown(dim)) {
1188       value = ic->Value(dim);
1189     }
1190     dims[i] = value;
1191   }
1192 }
1193 
1194 // TF_OperationDescription functions ------------------------------------------
1195 
1196 extern "C" {
1197 
TF_NewOperationLocked(TF_Graph * graph,const char * op_type,const char * oper_name)1198 static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
1199                                                       const char* op_type,
1200                                                       const char* oper_name)
1201     EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
1202   return new TF_OperationDescription(graph, op_type, oper_name);
1203 }
1204 
TF_NewOperation(TF_Graph * graph,const char * op_type,const char * oper_name)1205 TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
1206                                          const char* oper_name) {
1207   mutex_lock l(graph->mu);
1208   return TF_NewOperationLocked(graph, op_type, oper_name);
1209 }
1210 
TF_SetDevice(TF_OperationDescription * desc,const char * device)1211 void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
1212   desc->node_builder.Device(device);
1213 }
1214 
TF_AddInput(TF_OperationDescription * desc,TF_Output input)1215 void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
1216   desc->node_builder.Input(&input.oper->node, input.index);
1217 }
1218 
TF_AddInputList(TF_OperationDescription * desc,const TF_Output * inputs,int num_inputs)1219 void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
1220                      int num_inputs) {
1221   std::vector<NodeBuilder::NodeOut> input_list;
1222   input_list.reserve(num_inputs);
1223   for (int i = 0; i < num_inputs; ++i) {
1224     input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
1225   }
1226   desc->node_builder.Input(input_list);
1227 }
1228 
TF_AddControlInput(TF_OperationDescription * desc,TF_Operation * input)1229 void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
1230   desc->node_builder.ControlInput(&input->node);
1231 }
1232 
TF_ColocateWith(TF_OperationDescription * desc,TF_Operation * op)1233 void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
1234   desc->colocation_constraints.emplace(
1235       StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
1236 }
1237 
TF_SetAttrString(TF_OperationDescription * desc,const char * attr_name,const void * value,size_t length)1238 void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
1239                       const void* value, size_t length) {
1240   tensorflow::StringPiece s(static_cast<const char*>(value), length);
1241   desc->node_builder.Attr(attr_name, s);
1242 }
1243 
TF_SetAttrStringList(TF_OperationDescription * desc,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)1244 void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
1245                           const void* const* values, const size_t* lengths,
1246                           int num_values) {
1247   if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1248     desc->colocation_constraints.clear();
1249     for (int i = 0; i < num_values; ++i) {
1250       desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
1251                                            lengths[i]);
1252     }
1253   } else {
1254     std::vector<tensorflow::StringPiece> v;
1255     v.reserve(num_values);
1256     for (int i = 0; i < num_values; ++i) {
1257       v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
1258     }
1259     desc->node_builder.Attr(attr_name, v);
1260   }
1261 }
1262 
TF_SetAttrInt(TF_OperationDescription * desc,const char * attr_name,int64_t value)1263 void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
1264                    int64_t value) {
1265   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1266                 "64-bit int types should match in size");
1267   desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
1268 }
1269 
TF_SetAttrIntList(TF_OperationDescription * desc,const char * attr_name,const int64_t * values,int num_values)1270 void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
1271                        const int64_t* values, int num_values) {
1272   static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1273                 "64-bit int types should match in size");
1274   desc->node_builder.Attr(
1275       attr_name,
1276       ArraySlice<const tensorflow::int64>(
1277           reinterpret_cast<const tensorflow::int64*>(values), num_values));
1278 }
1279 
TF_SetAttrFloat(TF_OperationDescription * desc,const char * attr_name,float value)1280 void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
1281                      float value) {
1282   desc->node_builder.Attr(attr_name, value);
1283 }
1284 
TF_SetAttrFloatList(TF_OperationDescription * desc,const char * attr_name,const float * values,int num_values)1285 void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
1286                          const float* values, int num_values) {
1287   desc->node_builder.Attr(attr_name,
1288                           ArraySlice<const float>(values, num_values));
1289 }
1290 
TF_SetAttrBool(TF_OperationDescription * desc,const char * attr_name,unsigned char value)1291 void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
1292                     unsigned char value) {
1293   desc->node_builder.Attr(attr_name, static_cast<bool>(value));
1294 }
1295 
TF_SetAttrBoolList(TF_OperationDescription * desc,const char * attr_name,const unsigned char * values,int num_values)1296 void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
1297                         const unsigned char* values, int num_values) {
1298   std::unique_ptr<bool[]> b(new bool[num_values]);
1299   for (int i = 0; i < num_values; ++i) {
1300     b[i] = values[i];
1301   }
1302   desc->node_builder.Attr(attr_name,
1303                           ArraySlice<const bool>(b.get(), num_values));
1304 }
1305 
TF_SetAttrType(TF_OperationDescription * desc,const char * attr_name,TF_DataType value)1306 void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
1307                     TF_DataType value) {
1308   desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
1309 }
1310 
TF_SetAttrTypeList(TF_OperationDescription * desc,const char * attr_name,const TF_DataType * values,int num_values)1311 void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
1312                         const TF_DataType* values, int num_values) {
1313   desc->node_builder.Attr(
1314       attr_name, ArraySlice<const DataType>(
1315                      reinterpret_cast<const DataType*>(values), num_values));
1316 }
1317 
TF_SetAttrPlaceholder(TF_OperationDescription * desc,const char * attr_name,const char * placeholder)1318 void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name,
1319                            const char* placeholder) {
1320   tensorflow::AttrValue attr_value;
1321   attr_value.set_placeholder(placeholder);
1322   desc->node_builder.Attr(attr_name, attr_value);
1323 }
1324 
TF_SetAttrFuncName(TF_OperationDescription * desc,const char * attr_name,const char * value,size_t length)1325 void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
1326                         const char* value, size_t length) {
1327   tensorflow::NameAttrList func_name;
1328   func_name.set_name(string(value, value + length));
1329   desc->node_builder.Attr(attr_name, func_name);
1330 }
1331 
TF_SetAttrShape(TF_OperationDescription * desc,const char * attr_name,const int64_t * dims,int num_dims)1332 void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
1333                      const int64_t* dims, int num_dims) {
1334   PartialTensorShape shape;
1335   if (num_dims >= 0) {
1336     static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1337                   "64-bit int types should match in size");
1338     shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
1339         reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
1340   }
1341   desc->node_builder.Attr(attr_name, shape);
1342 }
1343 
TF_SetAttrShapeList(TF_OperationDescription * desc,const char * attr_name,const int64_t * const * dims,const int * num_dims,int num_shapes)1344 void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
1345                          const int64_t* const* dims, const int* num_dims,
1346                          int num_shapes) {
1347   std::vector<PartialTensorShape> shapes;
1348   shapes.reserve(num_shapes);
1349   for (int i = 0; i < num_shapes; ++i) {
1350     if (num_dims[i] < 0) {
1351       shapes.emplace_back();
1352     } else {
1353       static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1354                     "64-bit int types should match in size");
1355       shapes.emplace_back(ArraySlice<tensorflow::int64>(
1356           reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
1357     }
1358   }
1359   desc->node_builder.Attr(attr_name, shapes);
1360 }
1361 
TF_SetAttrTensorShapeProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)1362 void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
1363                                 const char* attr_name, const void* proto,
1364                                 size_t proto_len, TF_Status* status) {
1365   // shape.ParseFromArray takes an int as length, this function takes size_t,
1366   // make sure there is no information loss.
1367   if (proto_len > std::numeric_limits<int>::max()) {
1368     status->status = InvalidArgument(
1369         "proto_len (", proto_len,
1370         " bytes) is too large to be parsed by the protocol buffer library");
1371     return;
1372   }
1373   TensorShapeProto shape;
1374   if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
1375     desc->node_builder.Attr(attr_name, shape);
1376     status->status = Status::OK();
1377   } else {
1378     status->status = InvalidArgument("Unparseable TensorShapeProto");
1379   }
1380 }
1381 
TF_SetAttrTensorShapeProtoList(TF_OperationDescription * desc,const char * attr_name,const void * const * protos,const size_t * proto_lens,int num_shapes,TF_Status * status)1382 void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
1383                                     const char* attr_name,
1384                                     const void* const* protos,
1385                                     const size_t* proto_lens, int num_shapes,
1386                                     TF_Status* status) {
1387   std::vector<TensorShapeProto> shapes;
1388   shapes.resize(num_shapes);
1389   for (int i = 0; i < num_shapes; ++i) {
1390     if (proto_lens[i] > std::numeric_limits<int>::max()) {
1391       status->status = InvalidArgument(
1392           "length of element ", i, " in the list (", proto_lens[i],
1393           " bytes) is too large to be parsed by the protocol buffer library");
1394       return;
1395     }
1396     if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) {
1397       status->status =
1398           InvalidArgument("Unparseable TensorShapeProto at index ", i);
1399       return;
1400     }
1401   }
1402   desc->node_builder.Attr(attr_name, shapes);
1403   status->status = Status::OK();
1404 }
1405 
TF_SetAttrTensor(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * value,TF_Status * status)1406 void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
1407                       TF_Tensor* value, TF_Status* status) {
1408   Tensor t;
1409   status->status = TF_TensorToTensor(value, &t);
1410   if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t);
1411 }
1412 
TF_SetAttrTensorList(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * const * values,int num_values,TF_Status * status)1413 void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
1414                           TF_Tensor* const* values, int num_values,
1415                           TF_Status* status) {
1416   status->status = Status::OK();
1417   std::vector<Tensor> t;
1418   t.reserve(num_values);
1419 
1420   for (int i = 0; i < num_values && TF_GetCode(status) == TF_OK; ++i) {
1421     Tensor v;
1422     status->status = TF_TensorToTensor(values[i], &v);
1423     t.emplace_back(v);
1424   }
1425 
1426   if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t);
1427 }
1428 
TF_SetAttrValueProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)1429 void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
1430                           const void* proto, size_t proto_len,
1431                           TF_Status* status) {
1432   tensorflow::AttrValue attr_value;
1433   if (!attr_value.ParseFromArray(proto, proto_len)) {
1434     status->status = InvalidArgument("Unparseable AttrValue proto");
1435     return;
1436   }
1437 
1438   if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1439     if (attr_value.value_case() != tensorflow::AttrValue::kList &&
1440         attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
1441       status->status =
1442           InvalidArgument("Expected \"list\" field for \"",
1443                           tensorflow::kColocationAttrName, "\" attribute");
1444       return;
1445     }
1446     desc->colocation_constraints.clear();
1447     for (const string& location : attr_value.list().s()) {
1448       desc->colocation_constraints.insert(location);
1449     }
1450   } else {
1451     desc->node_builder.Attr(attr_name, attr_value);
1452   }
1453 
1454   status->status = Status::OK();
1455 }
1456 
TF_FinishOperationLocked(TF_OperationDescription * desc,TF_Status * status)1457 static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
1458                                               TF_Status* status)
1459     EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
1460   Node* ret = nullptr;
1461 
1462   if (desc->graph->name_map.count(desc->node_builder.node_name())) {
1463     status->status = InvalidArgument("Duplicate node name in graph: '",
1464                                      desc->node_builder.node_name(), "'");
1465   } else {
1466     if (!desc->colocation_constraints.empty()) {
1467       desc->node_builder.Attr(
1468           tensorflow::kColocationAttrName,
1469           std::vector<string>(desc->colocation_constraints.begin(),
1470                               desc->colocation_constraints.end()));
1471     }
1472     status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);
1473 
1474     if (TF_GetCode(status) == TF_OK) {
1475       // Run shape inference function for newly added node.
1476       status->status = desc->graph->refiner.AddNode(ret);
1477     }
1478     if (TF_GetCode(status) == TF_OK) {
1479       // Add the node to the name-to-node mapping.
1480       desc->graph->name_map[ret->name()] = ret;
1481     } else if (ret != nullptr) {
1482       desc->graph->graph.RemoveNode(ret);
1483       ret = nullptr;
1484     }
1485   }
1486 
1487   delete desc;
1488 
1489   return ToOperation(ret);
1490 }
1491 
TF_FinishOperation(TF_OperationDescription * desc,TF_Status * status)1492 TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
1493                                  TF_Status* status) {
1494   mutex_lock l(desc->graph->mu);
1495   return TF_FinishOperationLocked(desc, status);
1496 }
1497 
1498 // TF_Operation functions
1499 // ----------------------------------------------------------
1500 
TF_OperationName(TF_Operation * oper)1501 const char* TF_OperationName(TF_Operation* oper) {
1502   return oper->node.name().c_str();
1503 }
1504 
TF_OperationOpType(TF_Operation * oper)1505 const char* TF_OperationOpType(TF_Operation* oper) {
1506   return oper->node.type_string().c_str();
1507 }
1508 
TF_OperationDevice(TF_Operation * oper)1509 const char* TF_OperationDevice(TF_Operation* oper) {
1510   return oper->node.requested_device().c_str();
1511 }
1512 
TF_OperationNumOutputs(TF_Operation * oper)1513 int TF_OperationNumOutputs(TF_Operation* oper) {
1514   return oper->node.num_outputs();
1515 }
1516 
TF_OperationOutputType(TF_Output oper_out)1517 TF_DataType TF_OperationOutputType(TF_Output oper_out) {
1518   return static_cast<TF_DataType>(
1519       oper_out.oper->node.output_type(oper_out.index));
1520 }
1521 
TF_OperationOutputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1522 int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
1523                                  TF_Status* status) {
1524   NameRangeMap name_ranges;
1525   status->status =
1526       NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
1527   if (TF_GetCode(status) != TF_OK) return -1;
1528   auto iter = name_ranges.find(arg_name);
1529   if (iter == name_ranges.end()) {
1530     status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1531     return -1;
1532   }
1533   return iter->second.second - iter->second.first;
1534 }
1535 
TF_OperationNumInputs(TF_Operation * oper)1536 int TF_OperationNumInputs(TF_Operation* oper) {
1537   return oper->node.num_inputs();
1538 }
1539 
TF_OperationInputType(TF_Input oper_in)1540 TF_DataType TF_OperationInputType(TF_Input oper_in) {
1541   return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
1542 }
1543 
TF_OperationInputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1544 int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
1545                                 TF_Status* status) {
1546   NameRangeMap name_ranges;
1547   status->status =
1548       NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
1549   if (TF_GetCode(status) != TF_OK) return -1;
1550   auto iter = name_ranges.find(arg_name);
1551   if (iter == name_ranges.end()) {
1552     status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1553     return -1;
1554   }
1555   return iter->second.second - iter->second.first;
1556 }
1557 
TF_OperationInput(TF_Input oper_in)1558 TF_Output TF_OperationInput(TF_Input oper_in) {
1559   const tensorflow::Edge* edge;
1560   Status s = oper_in.oper->node.input_edge(oper_in.index, &edge);
1561   if (!s.ok()) {
1562     return {nullptr, -1};
1563   }
1564 
1565   return {ToOperation(edge->src()), edge->src_output()};
1566 }
1567 
TF_OperationOutputNumConsumers(TF_Output oper_out)1568 int TF_OperationOutputNumConsumers(TF_Output oper_out) {
1569   int count = 0;
1570   for (const auto* edge : oper_out.oper->node.out_edges()) {
1571     if (edge->src_output() == oper_out.index) {
1572       ++count;
1573     }
1574   }
1575   return count;
1576 }
1577 
TF_OperationOutputConsumers(TF_Output oper_out,TF_Input * consumers,int max_consumers)1578 int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
1579                                 int max_consumers) {
1580   int count = 0;
1581   for (const auto* edge : oper_out.oper->node.out_edges()) {
1582     if (edge->src_output() == oper_out.index) {
1583       if (count < max_consumers) {
1584         consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
1585       }
1586       ++count;
1587     }
1588   }
1589   return count;
1590 }
1591 
TF_OperationNumControlInputs(TF_Operation * oper)1592 int TF_OperationNumControlInputs(TF_Operation* oper) {
1593   int count = 0;
1594   for (const auto* edge : oper->node.in_edges()) {
1595     if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1596       ++count;
1597     }
1598   }
1599   return count;
1600 }
1601 
TF_OperationGetControlInputs(TF_Operation * oper,TF_Operation ** control_inputs,int max_control_inputs)1602 int TF_OperationGetControlInputs(TF_Operation* oper,
1603                                  TF_Operation** control_inputs,
1604                                  int max_control_inputs) {
1605   int count = 0;
1606   for (const auto* edge : oper->node.in_edges()) {
1607     if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1608       if (count < max_control_inputs) {
1609         control_inputs[count] = ToOperation(edge->src());
1610       }
1611       ++count;
1612     }
1613   }
1614   return count;
1615 }
1616 
TF_OperationNumControlOutputs(TF_Operation * oper)1617 int TF_OperationNumControlOutputs(TF_Operation* oper) {
1618   int count = 0;
1619   for (const auto* edge : oper->node.out_edges()) {
1620     if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1621       ++count;
1622     }
1623   }
1624   return count;
1625 }
1626 
TF_OperationGetControlOutputs(TF_Operation * oper,TF_Operation ** control_outputs,int max_control_outputs)1627 int TF_OperationGetControlOutputs(TF_Operation* oper,
1628                                   TF_Operation** control_outputs,
1629                                   int max_control_outputs) {
1630   int count = 0;
1631   for (const auto* edge : oper->node.out_edges()) {
1632     if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1633       if (count < max_control_outputs) {
1634         control_outputs[count] = ToOperation(edge->dst());
1635       }
1636       ++count;
1637     }
1638   }
1639   return count;
1640 }
1641 
TF_OperationGetAttrMetadata(TF_Operation * oper,const char * attr_name,TF_Status * status)1642 TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1643                                             const char* attr_name,
1644                                             TF_Status* status) {
1645   TF_AttrMetadata metadata;
1646   const auto* attr = GetAttrValue(oper, attr_name, status);
1647   if (TF_GetCode(status) != TF_OK) return metadata;
1648   switch (attr->value_case()) {
1649 #define SINGLE_CASE(kK, attr_type, size_expr) \
1650   case tensorflow::AttrValue::kK:             \
1651     metadata.is_list = 0;                     \
1652     metadata.list_size = -1;                  \
1653     metadata.type = attr_type;                \
1654     metadata.total_size = size_expr;          \
1655     break;
1656 
1657     SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
1658     SINGLE_CASE(kI, TF_ATTR_INT, -1);
1659     SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
1660     SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
1661     SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
1662     SINGLE_CASE(kShape, TF_ATTR_SHAPE,
1663                 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
1664     SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
1665 #undef SINGLE_CASE
1666 
1667     case tensorflow::AttrValue::kList:
1668       metadata.is_list = 1;
1669       metadata.list_size = 0;
1670       metadata.total_size = -1;
1671 #define LIST_CASE(field, attr_type, ...)              \
1672   if (attr->list().field##_size() > 0) {              \
1673     metadata.type = attr_type;                        \
1674     metadata.list_size = attr->list().field##_size(); \
1675     __VA_ARGS__;                                      \
1676     break;                                            \
1677   }
1678 
1679       LIST_CASE(
1680           s, TF_ATTR_STRING, metadata.total_size = 0;
1681           for (int i = 0; i < attr->list().s_size();
1682                ++i) { metadata.total_size += attr->list().s(i).size(); });
1683       LIST_CASE(i, TF_ATTR_INT);
1684       LIST_CASE(f, TF_ATTR_FLOAT);
1685       LIST_CASE(b, TF_ATTR_BOOL);
1686       LIST_CASE(type, TF_ATTR_TYPE);
1687       LIST_CASE(
1688           shape, TF_ATTR_SHAPE, metadata.total_size = 0;
1689           for (int i = 0; i < attr->list().shape_size(); ++i) {
1690             const auto& s = attr->list().shape(i);
1691             metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
1692           });
1693       LIST_CASE(tensor, TF_ATTR_TENSOR);
1694       LIST_CASE(tensor, TF_ATTR_FUNC);
1695 #undef LIST_CASE
1696       // All lists empty, determine the type from the OpDef.
1697       if (metadata.list_size == 0) {
1698         for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
1699           const auto& a = oper->node.op_def().attr(i);
1700           if (a.name().compare(attr_name) != 0) continue;
1701           const string& typestr = a.type();
1702           if (typestr == "list(string)") {
1703             metadata.type = TF_ATTR_STRING;
1704           } else if (typestr == "list(int)") {
1705             metadata.type = TF_ATTR_INT;
1706           } else if (typestr == "list(float)") {
1707             metadata.type = TF_ATTR_FLOAT;
1708           } else if (typestr == "list(bool)") {
1709             metadata.type = TF_ATTR_BOOL;
1710           } else if (typestr == "list(type)") {
1711             metadata.type = TF_ATTR_TYPE;
1712           } else if (typestr == "list(shape)") {
1713             metadata.type = TF_ATTR_SHAPE;
1714           } else if (typestr == "list(tensor)") {
1715             metadata.type = TF_ATTR_TENSOR;
1716           } else if (typestr == "list(func)") {
1717             metadata.type = TF_ATTR_FUNC;
1718           } else {
1719             status->status = InvalidArgument(
1720                 "Attribute '", attr_name,
1721                 "' has an empty value of an unrecognized type '", typestr, "'");
1722             return metadata;
1723           }
1724         }
1725       }
1726       break;
1727 
1728     case tensorflow::AttrValue::kPlaceholder:
1729       metadata.is_list = 0;
1730       metadata.list_size = -1;
1731       metadata.type = TF_ATTR_PLACEHOLDER;
1732       metadata.total_size = -1;
1733       break;
1734 
1735     case tensorflow::AttrValue::kFunc:
1736       metadata.is_list = 0;
1737       metadata.list_size = -1;
1738       metadata.type = TF_ATTR_FUNC;
1739       metadata.total_size = -1;
1740       break;
1741 
1742     case tensorflow::AttrValue::VALUE_NOT_SET:
1743       status->status =
1744           InvalidArgument("Attribute '", attr_name, "' has no value set");
1745       break;
1746   }
1747   return metadata;
1748 }
1749 
TF_OperationGetAttrString(TF_Operation * oper,const char * attr_name,void * value,size_t max_length,TF_Status * status)1750 void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
1751                                void* value, size_t max_length,
1752                                TF_Status* status) {
1753   const auto* attr = GetAttrValue(oper, attr_name, status);
1754   if (TF_GetCode(status) != TF_OK) return;
1755   if (attr->value_case() != tensorflow::AttrValue::kS) {
1756     status->status =
1757         InvalidArgument("Attribute '", attr_name, "' is not a string");
1758     return;
1759   }
1760   if (max_length <= 0) {
1761     return;
1762   }
1763   const auto& s = attr->s();
1764   std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length));
1765 }
1766 
TF_OperationGetAttrStringList(TF_Operation * oper,const char * attr_name,void ** values,size_t * lengths,int max_values,void * storage,size_t storage_size,TF_Status * status)1767 void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
1768                                    void** values, size_t* lengths,
1769                                    int max_values, void* storage,
1770                                    size_t storage_size, TF_Status* status) {
1771   const auto* attr = GetAttrValue(oper, attr_name, status);
1772   if (TF_GetCode(status) != TF_OK) return;
1773   if (attr->value_case() != tensorflow::AttrValue::kList) {
1774     status->status =
1775         InvalidArgument("Value for '", attr_name, "' is not a list");
1776     return;
1777   }
1778   const auto len = std::min(max_values, attr->list().s_size());
1779   char* p = static_cast<char*>(storage);
1780   for (int i = 0; i < len; ++i) {
1781     const string& s = attr->list().s(i);
1782     values[i] = p;
1783     lengths[i] = s.size();
1784     if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
1785       status->status = InvalidArgument(
1786           "Not enough storage to hold the requested list of strings");
1787       return;
1788     }
1789     memcpy(values[i], s.data(), s.size());
1790     p += s.size();
1791   }
1792 }
1793 
1794 #define DEFINE_GETATTR(func, c_type, cpp_type, list_field)                   \
1795   void func(TF_Operation* oper, const char* attr_name, c_type* value,        \
1796             TF_Status* status) {                                             \
1797     cpp_type v;                                                              \
1798     status->status =                                                         \
1799         tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v);          \
1800     *value = static_cast<c_type>(v);                                         \
1801   }                                                                          \
1802   void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
1803                   int max_values, TF_Status* status) {                       \
1804     const auto* attr = GetAttrValue(oper, attr_name, status);                \
1805     if (TF_GetCode(status) != TF_OK) return;                                 \
1806     if (attr->value_case() != tensorflow::AttrValue::kList) {                \
1807       status->status =                                                       \
1808           InvalidArgument("Value for '", attr_name, "' is not a list.");     \
1809       return;                                                                \
1810     }                                                                        \
1811     const auto len = std::min(max_values, attr->list().list_field##_size()); \
1812     for (int i = 0; i < len; ++i) {                                          \
1813       values[i] = static_cast<c_type>(attr->list().list_field(i));           \
1814     }                                                                        \
1815   }
1816 DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
1817 DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
1818 DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
1819 DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
1820 #undef DEFINE_GETATTR
1821 
TF_OperationGetAttrShape(TF_Operation * oper,const char * attr_name,int64_t * value,int num_dims,TF_Status * status)1822 void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
1823                               int64_t* value, int num_dims, TF_Status* status) {
1824   PartialTensorShape shape;
1825   status->status =
1826       tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
1827   if (TF_GetCode(status) != TF_OK) return;
1828   auto len = std::min(shape.dims(), num_dims);
1829   for (int i = 0; i < len; ++i) {
1830     value[i] = shape.dim_size(i);
1831   }
1832 }
1833 
TF_OperationGetAttrShapeList(TF_Operation * oper,const char * attr_name,int64_t ** values,int * num_dims,int max_values,int64_t * storage,int storage_size,TF_Status * status)1834 void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
1835                                   int64_t** values, int* num_dims,
1836                                   int max_values, int64_t* storage,
1837                                   int storage_size, TF_Status* status) {
1838   std::vector<PartialTensorShape> shapes;
1839   status->status =
1840       tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
1841   if (TF_GetCode(status) != TF_OK) return;
1842   auto len = std::min(static_cast<int>(shapes.size()), max_values);
1843   int64_t* p = storage;
1844   int storage_left = storage_size;
1845   for (int i = 0; i < len; ++i) {
1846     // shapes[i].dims() == -1 for shapes with an unknown rank.
1847     int64_t n = shapes[i].dims();
1848     num_dims[i] = n;
1849     values[i] = p;
1850     if (n < 0) {
1851       continue;
1852     }
1853     if (storage_left < n) {
1854       status->status = InvalidArgument(
1855           "Not enough storage to hold the requested list of shapes");
1856       return;
1857     }
1858     storage_left -= n;
1859     for (int j = 0; j < n; ++j, ++p) {
1860       *p = shapes[i].dim_size(j);
1861     }
1862   }
1863 }
1864 
TF_OperationGetAttrTensorShapeProto(TF_Operation * oper,const char * attr_name,TF_Buffer * value,TF_Status * status)1865 void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
1866                                          const char* attr_name,
1867                                          TF_Buffer* value, TF_Status* status) {
1868   const auto* attr = GetAttrValue(oper, attr_name, status);
1869   if (TF_GetCode(status) != TF_OK) return;
1870   if (attr->value_case() != tensorflow::AttrValue::kShape) {
1871     status->status =
1872         InvalidArgument("Value for '", attr_name, "' is not a shape.");
1873     return;
1874   }
1875   status->status = MessageToBuffer(attr->shape(), value);
1876 }
1877 
TF_OperationGetAttrTensorShapeProtoList(TF_Operation * oper,const char * attr_name,TF_Buffer ** values,int max_values,TF_Status * status)1878 void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
1879                                              const char* attr_name,
1880                                              TF_Buffer** values, int max_values,
1881                                              TF_Status* status) {
1882   const auto* attr = GetAttrValue(oper, attr_name, status);
1883   if (TF_GetCode(status) != TF_OK) return;
1884   if (attr->value_case() != tensorflow::AttrValue::kList) {
1885     status->status =
1886         InvalidArgument("Value for '", attr_name, "' is not a list");
1887     return;
1888   }
1889   const auto len = std::min(max_values, attr->list().shape_size());
1890   for (int i = 0; i < len; ++i) {
1891     values[i] = TF_NewBuffer();
1892     status->status = MessageToBuffer(attr->list().shape(i), values[i]);
1893     if (TF_GetCode(status) != TF_OK) {
1894       // Delete everything allocated to far, the operation has failed.
1895       for (int j = 0; j <= i; ++j) {
1896         TF_DeleteBuffer(values[j]);
1897       }
1898       return;
1899     }
1900   }
1901 }
1902 
TF_OperationGetAttrTensor(TF_Operation * oper,const char * attr_name,TF_Tensor ** value,TF_Status * status)1903 void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
1904                                TF_Tensor** value, TF_Status* status) {
1905   *value = nullptr;
1906   Tensor t;
1907   status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
1908   if (TF_GetCode(status) != TF_OK) return;
1909   *value = TF_TensorFromTensor(t, status);
1910 }
1911 
TF_OperationGetAttrTensorList(TF_Operation * oper,const char * attr_name,TF_Tensor ** values,int max_values,TF_Status * status)1912 void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
1913                                    TF_Tensor** values, int max_values,
1914                                    TF_Status* status) {
1915   std::vector<Tensor> ts;
1916   status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
1917   if (TF_GetCode(status) != TF_OK) return;
1918   const auto len = std::min(max_values, static_cast<int>(ts.size()));
1919   for (int i = 0; i < len; ++i) {
1920     values[i] = TF_TensorFromTensor(ts[i], status);
1921   }
1922 }
1923 
TF_OperationGetAttrValueProto(TF_Operation * oper,const char * attr_name,TF_Buffer * output_attr_value,TF_Status * status)1924 void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
1925                                    TF_Buffer* output_attr_value,
1926                                    TF_Status* status) {
1927   const auto* attr = GetAttrValue(oper, attr_name, status);
1928   if (TF_GetCode(status) != TF_OK) return;
1929   status->status = MessageToBuffer(*attr, output_attr_value);
1930 }
1931 
TF_OperationToNodeDef(TF_Operation * oper,TF_Buffer * output_node_def,TF_Status * status)1932 void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
1933                            TF_Status* status) {
1934   status->status = MessageToBuffer(oper->node.def(), output_node_def);
1935 }
1936 
1937 // TF_Graph functions ---------------------------------------------------------
1938 
TF_Graph()1939 TF_Graph::TF_Graph()
1940     : graph(tensorflow::OpRegistry::Global()),
1941       refiner(graph.versions().producer(), graph.op_registry()),
1942       delete_requested(false),
1943       parent(nullptr),
1944       parent_inputs(nullptr) {}
1945 
TF_NewGraph()1946 TF_Graph* TF_NewGraph() { return new TF_Graph; }
1947 
TF_DeleteGraph(TF_Graph * g)1948 void TF_DeleteGraph(TF_Graph* g) {
1949   if (g == nullptr) return;
1950   g->mu.lock();
1951   g->delete_requested = true;
1952   const bool del = g->sessions.empty();
1953   g->mu.unlock();
1954   if (del) delete g;
1955 }
1956 
TF_GraphOperationByName(TF_Graph * graph,const char * oper_name)1957 TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
1958   mutex_lock l(graph->mu);
1959   auto iter = graph->name_map.find(oper_name);
1960   if (iter == graph->name_map.end()) {
1961     return nullptr;
1962   } else {
1963     return ToOperation(iter->second);
1964   }
1965 }
1966 
TF_GraphNextOperation(TF_Graph * graph,size_t * pos)1967 TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
1968   if (*pos == 0) {
1969     // Advance past the first sentinel nodes in every graph (the source & sink).
1970     *pos += 2;
1971   } else {
1972     // Advance to the next node.
1973     *pos += 1;
1974   }
1975 
1976   mutex_lock l(graph->mu);
1977   while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) {
1978     Node* node = graph->graph.FindNodeId(*pos);
1979     // FindNodeId() returns nullptr for nodes that have been deleted.
1980     // We aren't currently allowing nodes to be deleted, but it is safer
1981     // to still check.
1982     if (node != nullptr) return ToOperation(node);
1983     *pos += 1;
1984   }
1985 
1986   // No more nodes.
1987   return nullptr;
1988 }
1989 
TF_GraphToGraphDef(TF_Graph * graph,TF_Buffer * output_graph_def,TF_Status * status)1990 void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
1991                         TF_Status* status) {
1992   GraphDef def;
1993   {
1994     mutex_lock l(graph->mu);
1995     graph->graph.ToGraphDef(&def);
1996   }
1997   status->status = MessageToBuffer(def, output_graph_def);
1998 }
1999 
TF_GraphGetOpDef(TF_Graph * graph,const char * op_name,TF_Buffer * output_op_def,TF_Status * status)2000 void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
2001                       TF_Buffer* output_op_def, TF_Status* status) {
2002   const OpDef* op_def;
2003   {
2004     mutex_lock l(graph->mu);
2005     status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
2006     if (TF_GetCode(status) != TF_OK) return;
2007   }
2008   status->status = MessageToBuffer(*op_def, output_op_def);
2009 }
2010 
TF_GraphVersions(TF_Graph * graph,TF_Buffer * output_version_def,TF_Status * status)2011 void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
2012                       TF_Status* status) {
2013   VersionDef versions;
2014   {
2015     mutex_lock l(graph->mu);
2016     versions = graph->graph.versions();
2017   }
2018   status->status = MessageToBuffer(versions, output_version_def);
2019 }
2020 
TF_NewImportGraphDefOptions()2021 TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
2022   return new TF_ImportGraphDefOptions;
2023 }
TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions * opts)2024 void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
2025   delete opts;
2026 }
TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions * opts,const char * prefix)2027 void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
2028                                        const char* prefix) {
2029   opts->opts.prefix = prefix;
2030 }
TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions * opts,const char * device)2031 void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts,
2032                                               const char* device) {
2033   opts->opts.default_device = device;
2034 }
2035 
TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions * opts,unsigned char uniquify_names)2036 void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
2037                                               unsigned char uniquify_names) {
2038   opts->opts.uniquify_names = uniquify_names;
2039 }
2040 
TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions * opts,unsigned char uniquify_prefix)2041 void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
2042                                                unsigned char uniquify_prefix) {
2043   opts->opts.uniquify_prefix = uniquify_prefix;
2044 }
2045 
TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions * opts,const char * src_name,int src_index,TF_Output dst)2046 void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
2047                                              const char* src_name,
2048                                              int src_index, TF_Output dst) {
2049   opts->tensor_id_data.push_back(src_name);
2050   const string& src_name_str = opts->tensor_id_data.back();
2051   // We don't need to store dst's name in tensor_id_data, since `dst` must
2052   // outlive the ImportGraphDef call.
2053   opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
2054 }
2055 
TF_ImportGraphDefOptionsRemapControlDependency(TF_ImportGraphDefOptions * opts,const char * src_name,TF_Operation * dst)2056 void TF_ImportGraphDefOptionsRemapControlDependency(
2057     TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
2058   opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
2059       TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
2060 }
2061 
TF_ImportGraphDefOptionsAddControlDependency(TF_ImportGraphDefOptions * opts,TF_Operation * oper)2062 extern void TF_ImportGraphDefOptionsAddControlDependency(
2063     TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
2064   opts->opts.control_dependencies.push_back(oper->node.name());
2065 }
2066 
TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions * opts,const char * oper_name,int index)2067 void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
2068                                              const char* oper_name, int index) {
2069   opts->tensor_id_data.push_back(oper_name);
2070   const string& oper_name_str = opts->tensor_id_data.back();
2071   opts->opts.return_tensors.emplace_back(oper_name_str, index);
2072 }
2073 
TF_ImportGraphDefOptionsNumReturnOutputs(const TF_ImportGraphDefOptions * opts)2074 int TF_ImportGraphDefOptionsNumReturnOutputs(
2075     const TF_ImportGraphDefOptions* opts) {
2076   return opts->opts.return_tensors.size();
2077 }
2078 
TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions * opts,const char * oper_name)2079 void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
2080                                                 const char* oper_name) {
2081   opts->opts.return_nodes.push_back(oper_name);
2082 }
2083 
TF_ImportGraphDefOptionsNumReturnOperations(const TF_ImportGraphDefOptions * opts)2084 int TF_ImportGraphDefOptionsNumReturnOperations(
2085     const TF_ImportGraphDefOptions* opts) {
2086   return opts->opts.return_nodes.size();
2087 }
2088 
TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults * results,int * num_outputs,TF_Output ** outputs)2089 void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
2090                                            int* num_outputs,
2091                                            TF_Output** outputs) {
2092   *num_outputs = results->return_tensors.size();
2093   *outputs = results->return_tensors.data();
2094 }
2095 
TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults * results,int * num_opers,TF_Operation *** opers)2096 void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
2097                                               int* num_opers,
2098                                               TF_Operation*** opers) {
2099   *num_opers = results->return_nodes.size();
2100   *opers = results->return_nodes.data();
2101 }
2102 
TF_ImportGraphDefResultsMissingUnusedInputMappings(TF_ImportGraphDefResults * results,int * num_missing_unused_input_mappings,const char *** src_names,int ** src_indexes)2103 void TF_ImportGraphDefResultsMissingUnusedInputMappings(
2104     TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
2105     const char*** src_names, int** src_indexes) {
2106   *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
2107   *src_names = results->missing_unused_key_names.data();
2108   *src_indexes = results->missing_unused_key_indexes.data();
2109 }
2110 
TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults * results)2111 void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
2112   delete results;
2113 }
2114 
GraphImportGraphDefLocked(TF_Graph * graph,const GraphDef & def,const TF_ImportGraphDefOptions * opts,TF_ImportGraphDefResults * tf_results,TF_Status * status)2115 static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
2116                                       const TF_ImportGraphDefOptions* opts,
2117                                       TF_ImportGraphDefResults* tf_results,
2118                                       TF_Status* status)
2119     EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
2120   const int last_node_id = graph->graph.num_node_ids();
2121   tensorflow::ImportGraphDefResults results;
2122   status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
2123                                               &graph->refiner, &results);
2124   if (TF_GetCode(status) != TF_OK) return;
2125 
2126   // Add new nodes to name_map
2127   for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
2128     auto* node = graph->graph.FindNodeId(i);
2129     if (node != nullptr) graph->name_map[node->name()] = node;
2130   }
2131 
2132   // Populate return_tensors
2133   DCHECK(tf_results->return_tensors.empty());
2134   tf_results->return_tensors.resize(results.return_tensors.size());
2135   for (int i = 0; i < results.return_tensors.size(); ++i) {
2136     tf_results->return_tensors[i].oper =
2137         ToOperation(results.return_tensors[i].first);
2138     tf_results->return_tensors[i].index = results.return_tensors[i].second;
2139   }
2140 
2141   // Populate return_nodes
2142   DCHECK(tf_results->return_nodes.empty());
2143   tf_results->return_nodes.resize(results.return_nodes.size());
2144   for (int i = 0; i < results.return_nodes.size(); ++i) {
2145     tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
2146   }
2147 
2148   // Populate missing unused map keys
2149   DCHECK(tf_results->missing_unused_key_names.empty());
2150   DCHECK(tf_results->missing_unused_key_indexes.empty());
2151   DCHECK(tf_results->missing_unused_key_names_data.empty());
2152 
2153   size_t size = results.missing_unused_input_map_keys.size();
2154   tf_results->missing_unused_key_names.resize(size);
2155   tf_results->missing_unused_key_indexes.resize(size);
2156 
2157   for (int i = 0; i < size; ++i) {
2158     TensorId id = results.missing_unused_input_map_keys[i];
2159     tf_results->missing_unused_key_names_data.emplace_back(id.first);
2160     tf_results->missing_unused_key_names[i] =
2161         tf_results->missing_unused_key_names_data.back().c_str();
2162     tf_results->missing_unused_key_indexes[i] = id.second;
2163   }
2164 }
2165 
TF_GraphImportGraphDefWithResults(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Status * status)2166 TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
2167     TF_Graph* graph, const TF_Buffer* graph_def,
2168     const TF_ImportGraphDefOptions* options, TF_Status* status) {
2169   GraphDef def;
2170   if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
2171                                        graph_def->length)) {
2172     status->status = InvalidArgument("Invalid GraphDef");
2173     return nullptr;
2174   }
2175   auto results = new TF_ImportGraphDefResults();
2176   mutex_lock l(graph->mu);
2177   GraphImportGraphDefLocked(graph, def, options, results, status);
2178   if (TF_GetCode(status) != TF_OK) {
2179     delete results;
2180     return nullptr;
2181   }
2182   return results;
2183 }
2184 
TF_GraphImportGraphDefWithReturnOutputs(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Output * return_outputs,int num_return_outputs,TF_Status * status)2185 void TF_GraphImportGraphDefWithReturnOutputs(
2186     TF_Graph* graph, const TF_Buffer* graph_def,
2187     const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
2188     int num_return_outputs, TF_Status* status) {
2189   if (num_return_outputs != options->opts.return_tensors.size()) {
2190     status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
2191                                      options->opts.return_tensors.size(),
2192                                      ", got ", num_return_outputs);
2193     return;
2194   }
2195   if (num_return_outputs > 0 && return_outputs == nullptr) {
2196     status->status = InvalidArgument(
2197         "'return_outputs' must be preallocated to length ", num_return_outputs);
2198     return;
2199   }
2200   GraphDef def;
2201   if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
2202                                        graph_def->length)) {
2203     status->status = InvalidArgument("Invalid GraphDef");
2204     return;
2205   }
2206   TF_ImportGraphDefResults results;
2207   mutex_lock l(graph->mu);
2208   GraphImportGraphDefLocked(graph, def, options, &results, status);
2209   DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
2210   memcpy(return_outputs, results.return_tensors.data(),
2211          num_return_outputs * sizeof(TF_Output));
2212 }
2213 
TF_GraphImportGraphDef(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Status * status)2214 void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
2215                             const TF_ImportGraphDefOptions* options,
2216                             TF_Status* status) {
2217   TF_ImportGraphDefResults* results =
2218       TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
2219   TF_DeleteImportGraphDefResults(results);
2220 }
2221 
2222 // While loop functions -------------------------------------------------------
2223 
2224 namespace {
2225 
2226 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2227 
2228 // Creates a placeholder representing an input to the cond or body graph.
2229 // TODO(skyewm): remove these from final graph
CreateInput(const TF_Output & parent_input,TF_Graph * g,const char * name,TF_Output * input,TF_Status * status)2230 bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
2231                  TF_Output* input, TF_Status* status) {
2232   TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
2233   TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
2234   // TODO(skyewm): set placeholder shape
2235   TF_Operation* oper = TF_FinishOperation(desc, status);
2236   if (TF_GetCode(status) != TF_OK) return false;
2237   *input = {oper, 0};
2238   return true;
2239 }
2240 
2241 // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
2242 // `src_inputs[i]` will have that input replaced with `dst_inputs[i]`.  `prefix`
2243 // will be prepended to copied node names. `control_deps` are nodes in
2244 // `dst_graph` that the copied `src_graph` nodes will have control dependencies
2245 // on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
2246 // in `dst_graph` will be returned. `return_nodes` must be non-null.
CopyGraph(Graph * src_graph,Graph * dst_graph,tensorflow::ShapeRefiner * dst_refiner,const TF_Output * src_inputs,const std::vector<tensorflow::Output> & dst_inputs,const string & prefix,const std::vector<tensorflow::Operation> & control_deps,const TF_Output * nodes_to_return,int nreturn_nodes,std::vector<tensorflow::Output> * return_nodes)2247 Status CopyGraph(Graph* src_graph, Graph* dst_graph,
2248                  tensorflow::ShapeRefiner* dst_refiner,
2249                  const TF_Output* src_inputs,
2250                  const std::vector<tensorflow::Output>& dst_inputs,
2251                  const string& prefix,
2252                  const std::vector<tensorflow::Operation>& control_deps,
2253                  const TF_Output* nodes_to_return, int nreturn_nodes,
2254                  std::vector<tensorflow::Output>* return_nodes) {
2255   DCHECK(return_nodes != nullptr);
2256   GraphDef gdef;
2257   src_graph->ToGraphDef(&gdef);
2258 
2259   tensorflow::ImportGraphDefOptions opts;
2260   opts.prefix = prefix;
2261 
2262   for (int i = 0; i < dst_inputs.size(); ++i) {
2263     opts.input_map[ToTensorId(src_inputs[i])] =
2264         TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
2265   }
2266   opts.skip_mapped_nodes = true;
2267 
2268   for (const tensorflow::Operation& op : control_deps) {
2269     opts.control_dependencies.push_back(op.node()->name());
2270   }
2271 
2272   for (int i = 0; i < nreturn_nodes; ++i) {
2273     opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
2274   }
2275 
2276   // TODO(skyewm): change to OutputTensor
2277   tensorflow::ImportGraphDefResults results;
2278   TF_RETURN_IF_ERROR(
2279       ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
2280 
2281   for (const auto& pair : results.return_tensors) {
2282     return_nodes->emplace_back(pair.first, pair.second);
2283   }
2284   return Status::OK();
2285 }
2286 
ValidateConstWhileParams(const TF_WhileParams & params,TF_Status * s)2287 bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
2288   if (params.cond_graph == nullptr || params.body_graph == nullptr ||
2289       params.cond_graph->parent == nullptr ||
2290       params.cond_graph->parent != params.body_graph->parent ||
2291       params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
2292       params.ninputs <= 0 || params.cond_inputs == nullptr ||
2293       params.body_inputs == nullptr || params.body_outputs == nullptr) {
2294     s->status = InvalidArgument(
2295         "TF_WhileParams must be created by successful TF_NewWhile() call");
2296     return false;
2297   }
2298   return true;
2299 }
2300 
ValidateInputWhileParams(const TF_WhileParams & params,TF_Status * s)2301 bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
2302   if (params.cond_output.oper == nullptr) {
2303     s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
2304     return false;
2305   }
2306   for (int i = 0; i < params.ninputs; ++i) {
2307     if (params.body_outputs[i].oper == nullptr) {
2308       s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
2309                                   "field isn't set");
2310       return false;
2311     }
2312   }
2313   if (params.name == nullptr) {
2314     s->status = InvalidArgument("TF_WhileParams `name` field is null");
2315     return false;
2316   }
2317   return true;
2318 }
2319 
2320 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2321 
FreeWhileResources(const TF_WhileParams * params)2322 void FreeWhileResources(const TF_WhileParams* params) {
2323   TF_DeleteGraph(params->cond_graph);
2324   TF_DeleteGraph(params->body_graph);
2325   delete[] params->cond_inputs;
2326   delete[] params->body_inputs;
2327   delete[] params->body_outputs;
2328 }
2329 
EmptyWhileParams()2330 TF_WhileParams EmptyWhileParams() {
2331   return {0,       nullptr, nullptr, {nullptr, 0},
2332           nullptr, nullptr, nullptr, nullptr};
2333 }
2334 
2335 }  // namespace
2336 
TF_NewWhile(TF_Graph * g,TF_Output * inputs,int ninputs,TF_Status * status)2337 TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
2338                            TF_Status* status) {
2339 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2340   status->status = tensorflow::errors::Unimplemented(
2341       "Creating while loops is not supported on mobile. File a bug at "
2342       "https://github.com/tensorflow/tensorflow/issues if this feature is "
2343       "important to you");
2344   return EmptyWhileParams();
2345 #else
2346   if (ninputs == 0) {
2347     status->status =
2348         InvalidArgument("TF_NewWhile() must be passed at least one input");
2349     return EmptyWhileParams();
2350   }
2351 
2352   TF_Graph* cond_graph = TF_NewGraph();
2353   TF_Graph* body_graph = TF_NewGraph();
2354   cond_graph->parent = g;
2355   cond_graph->parent_inputs = inputs;
2356   body_graph->parent = g;
2357   body_graph->parent_inputs = inputs;
2358 
2359   TF_Output* cond_inputs = new TF_Output[ninputs];
2360   TF_Output cond_output = {nullptr, -1};
2361   TF_Output* body_inputs = new TF_Output[ninputs];
2362   TF_Output* body_outputs = new TF_Output[ninputs];
2363   for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
2364   const char* name = nullptr;
2365 
2366   for (int i = 0; i < ninputs; ++i) {
2367     // TODO(skyewm): prefix names with underscore (requires some plumbing)
2368     if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
2369                      &cond_inputs[i], status)) {
2370       break;
2371     }
2372     if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
2373                      &body_inputs[i], status)) {
2374       break;
2375     }
2376   }
2377 
2378   TF_WhileParams params = {ninputs,    cond_graph,  cond_inputs,  cond_output,
2379                            body_graph, body_inputs, body_outputs, name};
2380 
2381   if (TF_GetCode(status) != TF_OK) {
2382     FreeWhileResources(&params);
2383     return EmptyWhileParams();
2384   }
2385   return params;
2386 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2387 }
2388 
2389 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2390 namespace {
2391 
2392 // TODO(skyewm): make nodes in while loop unfetchable like in Python version
TF_FinishWhileHelper(const TF_WhileParams * params,TF_Status * status,TF_Output * outputs)2393 void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
2394                           TF_Output* outputs) {
2395   if (!ValidateInputWhileParams(*params, status)) return;
2396 
2397   TF_Graph* parent = params->cond_graph->parent;
2398   TF_Output* parent_inputs = params->cond_graph->parent_inputs;
2399   int num_loop_vars = params->ninputs;
2400 
2401   mutex_lock l(parent->mu);
2402 
2403   // 'cond_fn' copies the cond graph into the parent graph.
2404   tensorflow::ops::CondGraphBuilderFn cond_fn =
2405       [params, parent](const tensorflow::Scope& scope,
2406                        const std::vector<tensorflow::Output>& inputs,
2407                        tensorflow::Output* output) {
2408         DCHECK_EQ(scope.graph(), &parent->graph);
2409         std::vector<tensorflow::Output> cond_output;
2410         TF_RETURN_IF_ERROR(CopyGraph(
2411             &params->cond_graph->graph, &parent->graph, &parent->refiner,
2412             params->cond_inputs, inputs, scope.impl()->name(),
2413             scope.impl()->control_deps(), &params->cond_output,
2414             /* nreturn_nodes */ 1, &cond_output));
2415         *output = cond_output[0];
2416         return Status::OK();
2417       };
2418 
2419   // 'body_fn' copies the body graph into the parent graph.
2420   tensorflow::ops::BodyGraphBuilderFn body_fn =
2421       [params, parent, num_loop_vars](
2422           const tensorflow::Scope& scope,
2423           const std::vector<tensorflow::Output>& inputs,
2424           std::vector<tensorflow::Output>* outputs) {
2425         DCHECK_EQ(scope.graph(), &parent->graph);
2426         TF_RETURN_IF_ERROR(
2427             CopyGraph(&params->body_graph->graph, &parent->graph,
2428                       &parent->refiner, params->body_inputs, inputs,
2429                       scope.impl()->name(), scope.impl()->control_deps(),
2430                       params->body_outputs, num_loop_vars, outputs));
2431         return Status::OK();
2432       };
2433 
2434   // Create the while loop using an internal scope.
2435   tensorflow::Scope scope =
2436       NewInternalScope(&parent->graph, &status->status, &parent->refiner)
2437           .NewSubScope(params->name);
2438 
2439   const int first_new_node_id = parent->graph.num_node_ids();
2440 
2441   tensorflow::OutputList loop_outputs;
2442   status->status = tensorflow::ops::BuildWhileLoop(
2443       scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
2444       body_fn, params->name, &loop_outputs);
2445 
2446   // Update name_map with newly-created ops.
2447   // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
2448   // a bad status. Once we fix this, we may want to return early instead of
2449   // executing the following code.
2450   for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
2451     Node* new_node = parent->graph.FindNodeId(i);
2452     if (new_node == nullptr) continue;
2453     parent->name_map[new_node->name()] = new_node;
2454   }
2455 
2456   // Populate 'outputs'.
2457   DCHECK_LE(loop_outputs.size(), num_loop_vars);
2458   for (int i = 0; i < loop_outputs.size(); ++i) {
2459     outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
2460   }
2461 }
2462 
2463 }  // namespace
2464 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2465 
TF_FinishWhile(const TF_WhileParams * params,TF_Status * status,TF_Output * outputs)2466 void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
2467                     TF_Output* outputs) {
2468 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2469   status->status = tensorflow::errors::Unimplemented(
2470       "Creating while loops is not supported on mobile. File a bug at "
2471       "https://github.com/tensorflow/tensorflow/issues if this feature is "
2472       "important to you");
2473 #else
2474   // If it appears the caller created or modified `params`, don't free resources
2475   if (!ValidateConstWhileParams(*params, status)) return;
2476   TF_FinishWhileHelper(params, status, outputs);
2477   FreeWhileResources(params);
2478 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2479 }
2480 
TF_AbortWhile(const TF_WhileParams * params)2481 void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
2482 
TF_AddGradients(TF_Graph * g,TF_Output * y,int ny,TF_Output * x,int nx,TF_Output * dx,TF_Status * status,TF_Output * dy)2483 void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
2484                      TF_Output* dx, TF_Status* status, TF_Output* dy) {
2485   TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy);
2486 }
2487 
TF_AddGradientsWithPrefix(TF_Graph * g,const char * prefix,TF_Output * y,int ny,TF_Output * x,int nx,TF_Output * dx,TF_Status * status,TF_Output * dy)2488 void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
2489                                int ny, TF_Output* x, int nx, TF_Output* dx,
2490                                TF_Status* status, TF_Output* dy) {
2491 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2492   status->status = tensorflow::errors::Unimplemented(
2493       "Adding gradients is not supported on mobile. File a bug at "
2494       "https://github.com/tensorflow/tensorflow/issues if this feature is "
2495       "important to you");
2496 #else
2497   std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
2498   std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
2499   std::vector<tensorflow::Output> dy_arg;
2500 
2501   {
2502     // We need to hold on to the lock while we have a scope that uses TF_Graph.
2503     mutex_lock graph_lock(g->mu);
2504 
2505     const int first_new_node_id = g->graph.num_node_ids();
2506 
2507     string prefix_cmp;
2508     const char* child_scope_name;
2509     if (prefix == nullptr) {
2510       child_scope_name = "gradients";
2511     } else {
2512       prefix_cmp = string(prefix) + "/";
2513       // The operation should fail if the provided name prefix has already been
2514       // used in this graph
2515       for (const auto& pair : g->name_map) {
2516         const string& name = pair.first;
2517         if (name.compare(prefix) == 0 ||
2518             tensorflow::str_util::StartsWith(name, prefix_cmp)) {
2519           status->status = InvalidArgument(
2520               "prefix [", prefix,
2521               "] conflicts with existing node in the graph named [", name, "]");
2522           return;
2523         }
2524       }
2525       child_scope_name = prefix;
2526     }
2527     tensorflow::Scope scope =
2528         NewInternalScope(&g->graph, &status->status, &g->refiner)
2529             .NewSubScope(child_scope_name);
2530 
2531     if (dx != nullptr) {
2532       std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
2533       status->status =
2534           AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
2535     } else {
2536       status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
2537     }
2538 
2539     // Update g->name_map with the name_map from the scope, which will contain
2540     // the new gradient ops.
2541     for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
2542       Node* n = g->graph.FindNodeId(i);
2543       if (n == nullptr) continue;
2544 
2545       // Adding the gradients to the graph can alter the prefix to prevent
2546       // name collisions only if this prefix has not been provided explicitly
2547       // by the user. If it was provided, assert that it remained intact.
2548       if (prefix != nullptr &&
2549           !tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) {
2550         status->status = tensorflow::errors::Internal(
2551             "BUG: The gradients prefix have been unexpectedly altered when "
2552             "adding the nodes to the graph. This is a bug. Please file an "
2553             "issue at https://github.com/tensorflow/tensorflow/issues.");
2554         return;
2555       }
2556       // We have a convoluted scheme here: Using the C++ graph construction API
2557       // to add potentially many nodes to the graph without running the checks
2558       // (such as uniqueness of the names of nodes) we run with other functions
2559       // that add a node to the graph (like TF_FinishOperation).
2560       if (!g->name_map.insert(std::make_pair(n->name(), n)).second) {
2561         status->status = tensorflow::errors::Internal(
2562             "BUG: The API allowed construction of a graph with duplicate node "
2563             "names (",
2564             n->name(),
2565             "). This is a bug. Please file an issue at "
2566             "https://github.com/tensorflow/tensorflow/issues.");
2567       }
2568     }
2569   }
2570 
2571   // Unpack the results from grad_outputs_arg.
2572   TFOutputsFromOutputs(dy_arg, dy);
2573 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2574 }
2575 
2576 // TF_Session functions ----------------------------------------------
2577 
TF_Session(tensorflow::Session * s,TF_Graph * g)2578 TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
2579     : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {}
2580 
TF_NewSession(TF_Graph * graph,const TF_SessionOptions * opt,TF_Status * status)2581 TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
2582                           TF_Status* status) {
2583   Session* session;
2584   status->status = NewSession(opt->options, &session);
2585   if (TF_GetCode(status) == TF_OK) {
2586     TF_Session* new_session = new TF_Session(session, graph);
2587     if (graph != nullptr) {
2588       mutex_lock l(graph->mu);
2589       graph->sessions[new_session] = "";
2590     }
2591     return new_session;
2592   } else {
2593     DCHECK_EQ(nullptr, session);
2594     return nullptr;
2595   }
2596 }
2597 
TF_LoadSessionFromSavedModel(const TF_SessionOptions * session_options,const TF_Buffer * run_options,const char * export_dir,const char * const * tags,int tags_len,TF_Graph * graph,TF_Buffer * meta_graph_def,TF_Status * status)2598 TF_Session* TF_LoadSessionFromSavedModel(
2599     const TF_SessionOptions* session_options, const TF_Buffer* run_options,
2600     const char* export_dir, const char* const* tags, int tags_len,
2601     TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
2602 // TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring
2603 // that the tensorflow/cc/saved_model:loader build target is mobile friendly.
2604 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2605   status->status = tensorflow::errors::Unimplemented(
2606       "Loading a SavedModel is not supported on mobile. File a bug at "
2607       "https://github.com/tensorflow/tensorflow/issues if this feature is "
2608       "important to you");
2609   return nullptr;
2610 #else
2611   mutex_lock l(graph->mu);
2612   if (!graph->name_map.empty()) {
2613     status->status = InvalidArgument("Graph is non-empty.");
2614     return nullptr;
2615   }
2616 
2617   RunOptions run_options_proto;
2618   if (run_options != nullptr && !run_options_proto.ParseFromArray(
2619                                     run_options->data, run_options->length)) {
2620     status->status = InvalidArgument("Unparseable RunOptions proto");
2621     return nullptr;
2622   }
2623 
2624   std::unordered_set<string> tag_set;
2625   for (int i = 0; i < tags_len; i++) {
2626     tag_set.insert(string(tags[i]));
2627   }
2628 
2629   tensorflow::SavedModelBundle bundle;
2630   status->status =
2631       tensorflow::LoadSavedModel(session_options->options, run_options_proto,
2632                                  export_dir, tag_set, &bundle);
2633   if (TF_GetCode(status) != TF_OK) return nullptr;
2634 
2635   // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
2636   // extends using GraphDefs. The Graph instance is different, but equivalent
2637   // to the one used to create the session.
2638   //
2639   // TODO(jhseu): When Session is modified to take Graphs instead of
2640   // GraphDefs, return the Graph generated in LoadSavedModel().
2641   TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
2642   TF_ImportGraphDefResults results;
2643   GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
2644                             import_opts, &results, status);
2645   TF_DeleteImportGraphDefOptions(import_opts);
2646   if (TF_GetCode(status) != TF_OK) return nullptr;
2647 
2648   if (meta_graph_def != nullptr) {
2649     status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
2650     if (TF_GetCode(status) != TF_OK) return nullptr;
2651   }
2652 
2653   TF_Session* session = new TF_Session(bundle.session.release(), graph);
2654 
2655   graph->sessions[session] = "";
2656   session->last_num_graph_nodes = graph->graph.num_node_ids();
2657   return session;
2658 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2659 }
2660 
TF_CloseSession(TF_Session * s,TF_Status * status)2661 void TF_CloseSession(TF_Session* s, TF_Status* status) {
2662   status->status = s->session->Close();
2663 }
2664 
TF_DeleteSession(TF_Session * s,TF_Status * status)2665 void TF_DeleteSession(TF_Session* s, TF_Status* status) {
2666   status->status = Status::OK();
2667   if (s == nullptr) return;
2668   TF_Graph* const graph = s->graph;
2669   if (graph != nullptr) {
2670     graph->mu.lock();
2671     graph->sessions.erase(s);
2672     const bool del = graph->delete_requested && graph->sessions.empty();
2673     graph->mu.unlock();
2674     if (del) delete graph;
2675   }
2676   delete s->session;
2677   delete s;
2678 }
2679 
TF_SessionRun(TF_Session * session,const TF_Buffer * run_options,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Buffer * run_metadata,TF_Status * status)2680 void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
2681                    const TF_Output* inputs, TF_Tensor* const* input_values,
2682                    int ninputs, const TF_Output* outputs,
2683                    TF_Tensor** output_values, int noutputs,
2684                    const TF_Operation* const* target_opers, int ntargets,
2685                    TF_Buffer* run_metadata, TF_Status* status) {
2686   // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2687   // directly, instead of requiring us to serialize to a GraphDef and
2688   // call Session::Extend().
2689   if (session->extend_before_run &&
2690       !ExtendSessionGraphHelper(session, status)) {
2691     return;
2692   }
2693 
2694   TF_Run_Setup(noutputs, output_values, status);
2695 
2696   // Convert from TF_Output and TF_Tensor to a string and Tensor.
2697   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2698   if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2699   for (int i = 0; i < ninputs; ++i) {
2700     input_pairs[i].first = OutputName(inputs[i]);
2701   }
2702 
2703   // Convert from TF_Output to string names.
2704   std::vector<string> output_names(noutputs);
2705   for (int i = 0; i < noutputs; ++i) {
2706     output_names[i] = OutputName(outputs[i]);
2707   }
2708 
2709   // Convert from TF_Operation* to string names.
2710   std::vector<string> target_names(ntargets);
2711   for (int i = 0; i < ntargets; ++i) {
2712     target_names[i] = target_opers[i]->node.name();
2713   }
2714 
2715   // Actually run.
2716   TF_Run_Helper(session->session, nullptr, run_options, input_pairs,
2717                 output_names, output_values, target_names, run_metadata,
2718                 status);
2719 }
2720 
TF_SessionPRunSetup(TF_Session * session,const TF_Output * inputs,int ninputs,const TF_Output * outputs,int noutputs,const TF_Operation * const * target_opers,int ntargets,const char ** handle,TF_Status * status)2721 void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
2722                          int ninputs, const TF_Output* outputs, int noutputs,
2723                          const TF_Operation* const* target_opers, int ntargets,
2724                          const char** handle, TF_Status* status) {
2725   *handle = nullptr;
2726 
2727   if (session->extend_before_run &&
2728       !ExtendSessionGraphHelper(session, status)) {
2729     return;
2730   }
2731 
2732   std::vector<string> input_names(ninputs);
2733   for (int i = 0; i < ninputs; ++i) {
2734     input_names[i] = OutputName(inputs[i]);
2735   }
2736 
2737   std::vector<string> output_names(noutputs);
2738   for (int i = 0; i < noutputs; ++i) {
2739     output_names[i] = OutputName(outputs[i]);
2740   }
2741 
2742   std::vector<string> target_names(ntargets);
2743   for (int i = 0; i < ntargets; ++i) {
2744     target_names[i] = target_opers[i]->node.name();
2745   }
2746 
2747   string new_handle;
2748   status->status = session->session->PRunSetup(input_names, output_names,
2749                                                target_names, &new_handle);
2750   if (TF_GetCode(status) == TF_OK) {
2751     char* buf = new char[new_handle.size() + 1];
2752     memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
2753     *handle = buf;
2754   }
2755 }
2756 
// Frees a handle string produced by TF_SessionPRunSetup, which allocates it
// with new[]. Passing nullptr is a harmless no-op (delete[] of null).
void TF_DeletePRunHandle(const char* handle) {
  delete[] handle;
  // TODO(suharshs): Free up any resources held by the partial run state.
}
2761 
TF_SessionPRun(TF_Session * session,const char * handle,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Status * status)2762 void TF_SessionPRun(TF_Session* session, const char* handle,
2763                     const TF_Output* inputs, TF_Tensor* const* input_values,
2764                     int ninputs, const TF_Output* outputs,
2765                     TF_Tensor** output_values, int noutputs,
2766                     const TF_Operation* const* target_opers, int ntargets,
2767                     TF_Status* status) {
2768   // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2769   // directly, instead of requiring us to serialize to a GraphDef and
2770   // call Session::Extend().
2771   if (session->extend_before_run &&
2772       !ExtendSessionGraphHelper(session, status)) {
2773     return;
2774   }
2775 
2776   TF_Run_Setup(noutputs, output_values, status);
2777 
2778   // Convert from TF_Output and TF_Tensor to a string and Tensor.
2779   std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2780   if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2781   for (int i = 0; i < ninputs; ++i) {
2782     input_pairs[i].first = OutputName(inputs[i]);
2783   }
2784 
2785   // Convert from TF_Output to string names.
2786   std::vector<string> output_names(noutputs);
2787   for (int i = 0; i < noutputs; ++i) {
2788     output_names[i] = OutputName(outputs[i]);
2789   }
2790 
2791   // Convert from TF_Operation* to string names.
2792   std::vector<string> target_names(ntargets);
2793   for (int i = 0; i < ntargets; ++i) {
2794     target_names[i] = target_opers[i]->node.name();
2795   }
2796 
2797   TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
2798                 output_values, target_names, nullptr, status);
2799 }
2800 
TF_TryEvaluateConstant(TF_Graph * graph,TF_Output output,TF_Tensor ** result,TF_Status * status)2801 unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
2802                                      TF_Tensor** result, TF_Status* status) {
2803   *result = nullptr;
2804   mutex_lock l(graph->mu);
2805   OutputTensor tensor(&output.oper->node, output.index);
2806   bool evaluated;
2807   Tensor result_tensor;
2808   status->status = EvaluateConstantTensor(
2809       tensor, graph->refiner, *graph->graph.op_registry(),
2810       graph->graph.versions().producer(), &evaluated, &result_tensor);
2811   if (evaluated) {
2812     DCHECK(TF_GetCode(status) == TF_OK);
2813     *result = TF_TensorFromTensor(result_tensor, status);
2814     if (TF_GetCode(status) != TF_OK) evaluated = false;
2815   }
2816   return evaluated;
2817 }
2818 
TF_NewApiDefMap(TF_Buffer * op_list_buffer,TF_Status * status)2819 TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
2820   tensorflow::OpList op_list;
2821   if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
2822     status->status = InvalidArgument("Unparseable OpList");
2823     return nullptr;
2824   }
2825   status->status = Status::OK();
2826   return new TF_ApiDefMap(op_list);
2827 }
2828 
TF_DeleteApiDefMap(TF_ApiDefMap * apimap)2829 void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }
2830 
TF_ApiDefMapPut(TF_ApiDefMap * api_def_map,const char * text,size_t text_len,TF_Status * status)2831 void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
2832                      size_t text_len, TF_Status* status) {
2833 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2834   status->status = tensorflow::errors::Unimplemented(
2835       "ApiDefMap is not supported on mobile.");
2836 #else
2837   mutex_lock l(api_def_map->lock);
2838   if (api_def_map->update_docs_called) {
2839     status->status = FailedPrecondition(
2840         "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
2841         "called.");
2842     return;
2843   }
2844   string api_def_text(text, text_len);
2845   status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
2846 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2847 }
2848 
TF_ApiDefMapGet(TF_ApiDefMap * api_def_map,const char * name,size_t name_len,TF_Status * status)2849 TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
2850                            size_t name_len, TF_Status* status) {
2851 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2852   status->status = tensorflow::errors::Unimplemented(
2853       "ApiDefMap is not supported on mobile.");
2854   return nullptr;
2855 #else
2856   mutex_lock l(api_def_map->lock);
2857   if (!api_def_map->update_docs_called) {
2858     api_def_map->api_def_map.UpdateDocs();
2859     api_def_map->update_docs_called = true;
2860   }
2861   string name_str(name, name_len);
2862   const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
2863   if (api_def == nullptr) {
2864     return nullptr;
2865   }
2866 
2867   TF_Buffer* ret = TF_NewBuffer();
2868   status->status = MessageToBuffer(*api_def, ret);
2869   if (TF_GetCode(status) != TF_OK) {
2870     TF_DeleteBuffer(ret);
2871     return nullptr;
2872   }
2873   return ret;
2874 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2875 }
2876 
TF_GetAllRegisteredKernels(TF_Status * status)2877 TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
2878   tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
2879   TF_Buffer* ret = TF_NewBuffer();
2880   status->status = MessageToBuffer(kernel_list, ret);
2881   if (TF_GetCode(status) != TF_OK) {
2882     TF_DeleteBuffer(ret);
2883     return nullptr;
2884   }
2885   return ret;
2886 }
2887 
TF_GetRegisteredKernelsForOp(const char * name,TF_Status * status)2888 TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
2889   tensorflow::KernelList kernel_list =
2890       tensorflow::GetRegisteredKernelsForOp(name);
2891   TF_Buffer* ret = TF_NewBuffer();
2892   status->status = MessageToBuffer(kernel_list, ret);
2893   if (TF_GetCode(status) != TF_OK) {
2894     TF_DeleteBuffer(ret);
2895     return nullptr;
2896   }
2897   return ret;
2898 }
2899 
2900 // TF_Server functions ----------------------------------------------
2901 
2902 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)2903 TF_Server::TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)
2904     : target(server->target()), server(std::move(server)) {}
2905 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2906 
TF_NewServer(const void * proto,size_t proto_len,TF_Status * status)2907 TF_Server* TF_NewServer(const void* proto, size_t proto_len,
2908                         TF_Status* status) {
2909 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2910   status->status = tensorflow::errors::Unimplemented(
2911       "Server functionality is not supported on mobile");
2912   return nullptr;
2913 #else
2914   tensorflow::ServerDef server_def;
2915   if (!server_def.ParseFromArray(proto, static_cast<int>(proto_len))) {
2916     status->status = InvalidArgument(
2917         "Could not parse provided bytes into a ServerDef protocol buffer");
2918     return nullptr;
2919   }
2920 
2921   std::unique_ptr<tensorflow::ServerInterface> out_server;
2922   status->status = tensorflow::NewServer(server_def, &out_server);
2923   if (TF_GetCode(status) != TF_OK) return nullptr;
2924 
2925   return new TF_Server(std::move(out_server));
2926 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2927 }
2928 
TF_ServerStart(TF_Server * server,TF_Status * status)2929 void TF_ServerStart(TF_Server* server, TF_Status* status) {
2930 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2931   status->status = tensorflow::errors::Unimplemented(
2932       "Server functionality is not supported on mobile");
2933 #else
2934   status->status = server->server->Start();
2935 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2936 }
2937 
TF_ServerStop(TF_Server * server,TF_Status * status)2938 void TF_ServerStop(TF_Server* server, TF_Status* status) {
2939 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2940   status->status = tensorflow::errors::Unimplemented(
2941       "Server functionality is not supported on mobile");
2942 #else
2943   status->status = server->server->Stop();
2944 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2945 }
2946 
TF_ServerJoin(TF_Server * server,TF_Status * status)2947 void TF_ServerJoin(TF_Server* server, TF_Status* status) {
2948 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2949   status->status = tensorflow::errors::Unimplemented(
2950       "Server functionality is not supported on mobile");
2951 #else
2952   status->status = server->server->Join();
2953 #endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2954 }
2955 
TF_ServerTarget(TF_Server * server)2956 const char* TF_ServerTarget(TF_Server* server) {
2957 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2958   return nullptr;
2959 #else
2960   return server->target.c_str();
2961 #endif
2962 }
2963 
TF_DeleteServer(TF_Server * server)2964 void TF_DeleteServer(TF_Server* server) {
2965 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2966   delete server;
2967 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2968 }
2969 
TF_RegisterLogListener(void (* listener)(const char *))2970 void TF_RegisterLogListener(void (*listener)(const char*)) {
2971 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2972   tensorflow::logging::RegisterListener(listener);
2973 #endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2974 }
2975 
2976 }  // end extern "C"
2977