/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/c_api.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <vector>

#ifndef __ANDROID__
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#endif
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/version.h"

// The implementation below is at the top level instead of the
// brain namespace because we are defining 'extern "C"' functions.
using tensorflow::AllocationDescription;
using tensorflow::DataType;
using tensorflow::Graph;
using tensorflow::GraphDef;
using tensorflow::mutex_lock;
using tensorflow::NameRangeMap;
using tensorflow::NameRangesForNode;
using tensorflow::NewSession;
using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::NodeDef;
using tensorflow::OpDef;
using tensorflow::OpRegistry;
using tensorflow::PartialTensorShape;
using tensorflow::RunMetadata;
using tensorflow::RunOptions;
using tensorflow::Session;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::Tensor;
using tensorflow::TensorBuffer;
using tensorflow::TensorId;
using tensorflow::TensorShape;
using tensorflow::TensorShapeProto;
using tensorflow::VersionDef;
using tensorflow::error::Code;
using tensorflow::errors::FailedPrecondition;
using tensorflow::errors::InvalidArgument;
using tensorflow::gtl::ArraySlice;
using tensorflow::strings::StrCat;

extern "C" {

// --------------------------------------------------------------------------
const char* TF_Version() { return TF_VERSION_STRING; }

// --------------------------------------------------------------------------
size_t TF_DataTypeSize(TF_DataType dt) {
  return static_cast<size_t>(
      tensorflow::DataTypeSize(static_cast<DataType>(dt)));
}

// --------------------------------------------------------------------------

TF_Status* TF_NewStatus() { return new TF_Status; }

void TF_DeleteStatus(TF_Status* s) { delete s; }

void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
  if (code == TF_OK) {
    s->status = Status::OK();
    return;
  }
  s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
}

TF_Code TF_GetCode(const TF_Status* s) {
  return static_cast<TF_Code>(s->status.code());
}

const char* TF_Message(const TF_Status* s) {
  return s->status.error_message().c_str();
}

// --------------------------------------------------------------------------

namespace {
class TF_ManagedBuffer : public TensorBuffer {
 public:
  void* data_;
  size_t len_;
  void (*deallocator_)(void* data, size_t len, void* arg);
  void* deallocator_arg_;

  ~TF_ManagedBuffer() override {
    (*deallocator_)(data_, len_, deallocator_arg_);
  }

  void* data() const override { return data_; }
  size_t size() const override { return len_; }
  TensorBuffer* root_buffer() override { return this; }
  void FillAllocationDescription(AllocationDescription* proto) const override {
    tensorflow::int64 rb = size();
    proto->set_requested_bytes(rb);
    proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
  }

  // Prevents input forwarding from mutating this buffer.
  bool OwnsMemory() const override { return false; }
};

void* allocate_tensor(const char* operation, size_t len) {
  void* data =
      tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len);
  if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
    tensorflow::LogMemory::RecordRawAllocation(
        operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID,
        len, data, tensorflow::cpu_allocator());
  }
  return data;
}

void deallocate_buffer(void* data, size_t len, void* arg) {
  if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
    tensorflow::LogMemory::RecordRawDeallocation(
        "TensorFlow C Api",
        tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data,
        tensorflow::cpu_allocator(), false);
  }
  tensorflow::cpu_allocator()->DeallocateRaw(data);
}

}  // namespace

TF_Tensor::~TF_Tensor() { buffer->Unref(); }

TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
                             int num_dims, size_t len) {
  void* data = allocate_tensor("TF_AllocateTensor", len);
  return TF_NewTensor(dtype, dims, num_dims, data, len, deallocate_buffer,
                      nullptr);
}

TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
                        void* data, size_t len,
                        void (*deallocator)(void* data, size_t len, void* arg),
                        void* deallocator_arg) {
  std::vector<tensorflow::int64> dimvec(num_dims);
  for (int i = 0; i < num_dims; ++i) {
    dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
  }

  TF_ManagedBuffer* buf = new TF_ManagedBuffer;
  buf->len_ = len;
  if (dtype != TF_STRING && dtype != TF_RESOURCE &&
      tensorflow::DataTypeCanUseMemcpy(static_cast<DataType>(dtype)) &&
      reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) {
    // TF_STRING and TF_RESOURCE tensors have a different representation in
    // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste
    // (any alignment requirements will be taken care of by TF_TensorToTensor
    // and TF_TensorFromTensor).
    //
    // Other types have the same representation, so copy only if it is safe to
    // do so.
    buf->data_ = allocate_tensor("TF_NewTensor", len);
    std::memcpy(buf->data_, data, len);
    buf->deallocator_ = deallocate_buffer;
    buf->deallocator_arg_ = nullptr;
    // Free the original buffer.
    deallocator(data, len, deallocator_arg);
  } else {
    buf->data_ = data;
    buf->deallocator_ = deallocator;
    buf->deallocator_arg_ = deallocator_arg;
  }
  TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf};
  size_t elem_size = TF_DataTypeSize(dtype);
  if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) {
    delete ret;
    return nullptr;
  }
  return ret;
}
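
// Illustrative sketch: a typical TF_NewTensor call from client code, assuming
// a caller-owned malloc'd buffer named `values`. The deallocator is invoked
// by TensorFlow once the data is no longer needed (immediately, if the
// misaligned-copy path above is taken).
//
//   int64_t dims[1] = {4};
//   float* values = static_cast<float*>(malloc(4 * sizeof(float)));
//   TF_Tensor* t = TF_NewTensor(
//       TF_FLOAT, dims, 1, values, 4 * sizeof(float),
//       [](void* data, size_t, void*) { free(data); }, nullptr);
//   // ... use TF_TensorData(t) ...
//   TF_DeleteTensor(t);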

TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
  // It is safe to move the Tensor if and only if we own the unique reference to
  // it. In that case, we might as well not delete and reallocate, but a future
  // implementation might need to do so.
  TensorBuffer* buf = tensor->buffer;
  if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
      buf->OwnsMemory()) {
    return tensor;
  }
  return nullptr;
}

void TF_DeleteTensor(TF_Tensor* t) { delete t; }

TF_DataType TF_TensorType(const TF_Tensor* t) { return t->dtype; }
int TF_NumDims(const TF_Tensor* t) { return t->shape.dims(); }
int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
  return static_cast<int64_t>(t->shape.dim_size(dim_index));
}
size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); }
void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); }

// --------------------------------------------------------------------------
size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
                       size_t dst_len, TF_Status* status) {
  const size_t sz = TF_StringEncodedSize(src_len);
  if (sz < src_len) {
    status->status = InvalidArgument("src string is too large to encode");
    return 0;
  }
  if (dst_len < sz) {
    status->status =
        InvalidArgument("dst_len (", dst_len, ") too small to encode a ",
                        src_len, "-byte string");
    return 0;
  }
  dst = tensorflow::core::EncodeVarint64(dst, src_len);
  memcpy(dst, src, src_len);
  return sz;
}

static Status TF_StringDecode_Impl(const char* src, size_t src_len,
                                   const char** dst, size_t* dst_len) {
  tensorflow::uint64 len64 = 0;
  const char* p = tensorflow::core::GetVarint64Ptr(src, src + src_len, &len64);
  if (p == nullptr) {
    return InvalidArgument("invalid string encoding or truncated src buffer");
  }
  if (len64 > std::numeric_limits<size_t>::max()) {
    return InvalidArgument("encoded string is ", len64,
                           "-bytes, which is too large for this architecture");
  }
  *dst = p;
  *dst_len = static_cast<size_t>(len64);
  return Status::OK();
}

size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
                       size_t* dst_len, TF_Status* status) {
  status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len);
  if (!status->status.ok()) return 0;
  return static_cast<size_t>(*dst - src) + *dst_len;
}

size_t TF_StringEncodedSize(size_t len) {
  return static_cast<size_t>(tensorflow::core::VarintLength(len)) + len;
}
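
// Illustrative sketch: round-tripping one string through the varint
// length-prefixed encoding used for TF_STRING tensor elements, using only the
// functions defined above.
//
//   TF_Status* s = TF_NewStatus();
//   const char* src = "hello";
//   size_t enc_len = TF_StringEncodedSize(5);  // 1 varint byte + 5 bytes
//   std::vector<char> enc(enc_len);
//   TF_StringEncode(src, 5, enc.data(), enc_len, s);
//   const char* decoded = nullptr;
//   size_t decoded_len = 0;
//   TF_StringDecode(enc.data(), enc_len, &decoded, &decoded_len, s);
//   // decoded now points just past the varint inside `enc`; decoded_len == 5.
//   TF_DeleteStatus(s);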

// --------------------------------------------------------------------------
TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }

void TF_SetTarget(TF_SessionOptions* options, const char* target) {
  options->options.target = target;
}

void TF_SetConfig(TF_SessionOptions* options, const void* proto,
                  size_t proto_len, TF_Status* status) {
  if (!options->options.config.ParseFromArray(proto, proto_len)) {
    status->status = InvalidArgument("Unparseable ConfigProto");
  }
}
// --------------------------------------------------------------------------
TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; }

TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
  void* copy = tensorflow::port::Malloc(proto_len);
  memcpy(copy, proto, proto_len);

  TF_Buffer* buf = new TF_Buffer;
  buf->data = copy;
  buf->length = proto_len;
  buf->data_deallocator = [](void* data, size_t length) {
    tensorflow::port::Free(data);
  };
  return buf;
}

void TF_DeleteBuffer(TF_Buffer* buffer) {
  if (buffer->data_deallocator != nullptr) {
    (*buffer->data_deallocator)(const_cast<void*>(buffer->data),
                                buffer->length);
  }
  delete buffer;
}

TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }

// --------------------------------------------------------------------------

TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
                                              TF_Status* status) {
  Session* session;
  status->status = NewSession(opt->options, &session);
  if (status->status.ok()) {
    return new TF_DeprecatedSession({session});
  } else {
    DCHECK_EQ(nullptr, session);
    return nullptr;
  }
}

void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
  status->status = s->session->Close();
}

void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
  status->status = Status::OK();
  delete s->session;
  delete s;
}

void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
                    size_t proto_len, TF_Status* status) {
  GraphDef g;
  if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
    status->status = InvalidArgument("Invalid GraphDef");
    return;
  }
  status->status = s->session->Extend(g);
}

static void DeleteArray(void* data, size_t size, void* arg) {
  DCHECK_EQ(data, arg);
  delete[] reinterpret_cast<char*>(arg);
}

}  // end extern "C"

namespace tensorflow {
namespace {

// Reset helper for converting character arrays to string vectors.
void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers,
                     int ncontainers, TF_Status* status) {
  std::vector<string> container_names(ncontainers);
  for (int i = 0; i < ncontainers; ++i) {
    container_names[i] = containers[i];
  }

  status->status = Reset(opt->options, container_names);
}

// This traverses the specified nodes in topological order to verify there are
// no cycles. Starting with inputless nodes, it visits nodes whose inputs have
// all been visited, and counts the total number of visited nodes. If there is a
// cycle, nodes in the cycle will never be visited, and the visited count will
// be less than the total node count.
Status ValidateNoCycles(const Graph& g) {
  // TODO(nolivia): check this on a subset of the graph instead of all of it.
  // A node is ready when all of its inputs have been visited.
  std::vector<const Node*> ready;
  std::vector<int> pending_count(g.num_node_ids(), 0);

  for (int i = 0; i < g.num_node_ids(); ++i) {
    const Node* n = g.FindNodeId(i);
    if (n == nullptr) continue;
    pending_count[i] = n->in_edges().size();
    if (n->IsMerge()) {
      // While-loop cycles are legal cycles so we manually adjust the
      // pending_count to make sure that the loop is visited.
      for (const Edge* e : n->in_edges()) {
        if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
          pending_count[i]--;
        }
      }
    }
    if (pending_count[i] == 0) {
      ready.push_back(n);
    }
  }

  int processed = 0;
  while (!ready.empty()) {
    const Node* node = ready.back();
    ready.pop_back();
    ++processed;

    for (const Edge* out : node->out_edges()) {
      const int output_id = out->dst()->id();
      pending_count[output_id]--;
      if (pending_count[output_id] == 0) {
        ready.push_back(out->dst());
      }
    }
  }

  if (processed < g.num_nodes()) {
    std::vector<string> nodes_in_cycle;
    for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3;
         ++i) {
      if (pending_count[i] != 0) {
        nodes_in_cycle.push_back(g.FindNodeId(i)->name());
      }
    }
    return errors::InvalidArgument(
        "Graph is invalid, contains a cycle with ", g.num_nodes() - processed,
        " nodes, including: ", str_util::Join(nodes_in_cycle, ", "));
  }
  return Status::OK();
}
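
// Worked example (illustrative): in a graph with edges A->B, B->C and C->B,
// the initial pending counts are {A:0, B:2, C:1}, so only A starts ready.
// Visiting A drops B to 1, nothing else ever reaches 0, and processed stays
// at 1 while the graph has 3 nodes, so the cycle through B and C is reported.
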
}  // namespace
}  // namespace tensorflow

extern "C" {

void TF_Reset(const TF_SessionOptions* opt, const char** containers,
              int ncontainers, TF_Status* status) {
  tensorflow::TF_Reset_Helper(opt, containers, ncontainers, status);
}

}  // end extern "C"

namespace tensorflow {

Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
  if (src->dtype == TF_RESOURCE) {
    if (src->shape.dims() != 0) {
      return InvalidArgument(
          "Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
          "shape ",
          src->shape.DebugString());
    }
    *dst = Tensor(DT_RESOURCE, src->shape);
    if (!dst->scalar<ResourceHandle>()().ParseFromString(
            string(static_cast<const char*>(TF_TensorData(src)),
                   TF_TensorByteSize(src)))) {
      return InvalidArgument(
          "Malformed TF_RESOURCE tensor: unable to parse resource handle");
    }
    return Status::OK();
  }
  if (src->dtype != TF_STRING) {
    *dst = TensorCApi::MakeTensor(src->dtype, src->shape, src->buffer);
    return Status::OK();
  }
  // TF_STRING tensors require copying since Tensor class expects a sequence of
  // string objects.
  const tensorflow::int64 num_elements = src->shape.num_elements();
  const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
  const size_t src_size = TF_TensorByteSize(src);
  if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
      num_elements) {
    return InvalidArgument(
        "Malformed TF_STRING tensor; too short to hold number of elements");
  }
  const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
  const char* limit = input + src_size;

  *dst = Tensor(static_cast<DataType>(src->dtype), src->shape);
  auto dstarray = dst->flat<string>();
  for (tensorflow::int64 i = 0; i < num_elements; ++i) {
    tensorflow::uint64 offset =
        reinterpret_cast<const tensorflow::uint64*>(input)[i];
    if (static_cast<ptrdiff_t>(offset) >= (limit - data_start)) {
      return InvalidArgument("Malformed TF_STRING tensor; element ", i,
                             " out of range");
    }
    size_t len;
    const char* p;
    const char* srcp = data_start + offset;
    Status status = TF_StringDecode_Impl(srcp, limit - srcp, &p, &len);
    if (!status.ok()) return status;
    dstarray(i).assign(p, len);
  }
  return Status::OK();
}

// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
// result in a zero-sized tensor.
static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) {
  static char empty;
  tensorflow::int64 nelems = 1;
  std::vector<tensorflow::int64> dims;
  for (int i = 0; i < shape.dims(); ++i) {
    dims.push_back(shape.dim_size(i));
    nelems *= shape.dim_size(i);
  }
  CHECK_EQ(nelems, 0);
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  return TF_NewTensor(dtype, reinterpret_cast<const int64_t*>(dims.data()),
                      shape.dims(), reinterpret_cast<void*>(&empty), 0,
                      [](void*, size_t, void*) {}, nullptr);
}

// Non-static for testing.
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
                               TF_Status* status) {
  if (!src.IsInitialized()) {
    status->status = FailedPrecondition(
        "attempt to use a tensor with an uninitialized value");
    return nullptr;
  }
  if (src.NumElements() == 0) {
    return EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
  }
  if (src.dtype() == DT_RESOURCE) {
    if (src.shape().dims() != 0) {
      status->status = InvalidArgument(
          "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
          src.shape().DebugString(),
          "). Please file a bug at "
          "https://github.com/tensorflow/tensorflow/issues/new, "
          "ideally with a "
          "short code snippet that reproduces this error.");
      return nullptr;
    }
    const string str = src.scalar<ResourceHandle>()().SerializeAsString();
    TF_Tensor* t = TF_AllocateTensor(TF_RESOURCE, {}, 0, str.size());
    std::memcpy(TF_TensorData(t), str.c_str(), str.size());
    return t;
  }
  if (src.dtype() != DT_STRING) {
    TensorBuffer* buf = TensorCApi::Buffer(src);
    buf->Ref();
    return new TF_Tensor{static_cast<TF_DataType>(src.dtype()), src.shape(),
                         buf};
  }
  // DT_STRING tensors require a copy since TF_Tensor.buffer expects a flatly
  // encoded sequence of strings.

  // Compute bytes needed for encoding.
  size_t size = 0;
  const auto& srcarray = src.flat<string>();
  for (int i = 0; i < srcarray.size(); ++i) {
    const string& s = srcarray(i);
    // uint64 starting_offset, TF_StringEncode-d string.
    size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size());
  }

  // Encode all strings.
  char* base = new char[size];
  char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size();
  char* dst = data_start;  // Where next string is encoded.
  size_t dst_len = size - static_cast<size_t>(data_start - base);
  tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
  for (int i = 0; i < srcarray.size(); ++i) {
    *offsets = (dst - data_start);
    offsets++;
    const string& s = srcarray(i);
    size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
    if (!status->status.ok()) {
      status->status = InvalidArgument(
          "invalid string tensor encoding (string #", i, " of ",
          srcarray.size(), "): ", status->status.error_message());
      delete[] base;
      return nullptr;
    }
    dst += consumed;
    dst_len -= consumed;
  }
  if (dst != base + size) {
    status->status = InvalidArgument(
        "invalid string tensor encoding (decoded ", (dst - base),
        " bytes, but the tensor is encoded in ", size, " bytes)");
    delete[] base;
    return nullptr;
  }

  auto dims = src.shape().dim_sizes();
  std::vector<tensorflow::int64> dimvec(dims.size());
  for (size_t i = 0; i < dims.size(); ++i) {
    dimvec[i] = dims[i];
  }
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  return TF_NewTensor(TF_STRING,
                      reinterpret_cast<const int64_t*>(dimvec.data()),
                      dimvec.size(), base, size, DeleteArray, base);
}
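
// Illustrative layout sketch: a TF_STRING tensor holding the two elements
// {"ab", "c"} is laid out in its buffer as an offset table followed by the
// encoded strings, which is what TF_TensorToTensor parses back:
//
//   bytes  0..7   uint64 offset of element 0 from data_start  (= 0)
//   bytes  8..15  uint64 offset of element 1 from data_start  (= 3)
//   data_start    varint length 2, "ab"   (3 bytes)
//                 varint length 1, "c"    (2 bytes)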

Status MessageToBuffer(const tensorflow::protobuf::Message& in,
                       TF_Buffer* out) {
  if (out->data != nullptr) {
    return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
  }
  const size_t proto_size = in.ByteSizeLong();
  void* buf = tensorflow::port::Malloc(proto_size);
  if (buf == nullptr) {
    return tensorflow::errors::ResourceExhausted(
        "Failed to allocate memory to serialize message of type '",
        in.GetTypeName(), "' and size ", proto_size);
  }
  in.SerializeToArray(buf, proto_size);
  out->data = buf;
  out->length = proto_size;
  out->data_deallocator = [](void* data, size_t length) {
    tensorflow::port::Free(data);
  };
  return Status::OK();
}

void RecordMutation(TF_Graph* graph, const TF_Operation& op,
                    const char* mutation_type)
    EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
  // If any session has already run this node_id, mark this session as
  // unrunnable.
  for (auto it : graph->sessions) {
    if (it.first->last_num_graph_nodes > op.node.id()) {
      it.second = FailedPrecondition(
          "Operation '", op.node.DebugString(), "' was changed by ",
          mutation_type,
          " after it was run by a session. Nodes can be mutated "
          "only before they are executed by a session. Either don't modify "
          "nodes after running them or create a new session.");
    }
  }
}

namespace {

// Helper method that creates a shape handle for a shape described by dims.
tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
    tensorflow::shape_inference::InferenceContext* ic, int num_dims,
    const int64_t* dims) {
  if (num_dims != -1) {
    std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
    dim_vec.reserve(num_dims);
    for (int i = 0; i < num_dims; ++i) {
      dim_vec.push_back(ic->MakeDim(dims[i]));
    }
    return ic->MakeShape(dim_vec);
  } else {
    return ic->UnknownShape();
  }
}

}  // namespace

void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
                                           int num_shapes_and_types,
                                           const int64_t** shapes,
                                           const int* ranks,
                                           const TF_DataType* types,
                                           TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return;
  }

  auto shape_and_type_vec =
      std::vector<tensorflow::shape_inference::ShapeAndType>(
          num_shapes_and_types);
  for (int i = 0; i < num_shapes_and_types; ++i) {
    tensorflow::shape_inference::ShapeHandle shape_handle =
        ShapeHandleFromDims(ic, ranks[i], shapes[i]);
    shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
        shape_handle, static_cast<DataType>(types[i]));
  }

  ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
}

// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result,
                   const void** buf, size_t* len);

}  // namespace tensorflow

static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
                         TF_Status* status) {
  status->status = Status::OK();
  for (int i = 0; i < noutputs; ++i) {
    c_outputs[i] = nullptr;
  }
}

static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
                          std::vector<std::pair<string, Tensor>>* input_pairs,
                          TF_Status* status) {
  const int ninputs = input_pairs->size();
  for (int i = 0; i < ninputs; ++i) {
    status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
    if (!status->status.ok()) return false;
  }
  return true;
}

static void TF_Run_Helper(
    Session* session, const char* handle, const TF_Buffer* run_options,
    // Input tensors
    const std::vector<std::pair<string, Tensor>>& input_pairs,
    // Output tensors
    const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
    // Target nodes
    const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
    TF_Status* status) {
  const int noutputs = output_tensor_names.size();
  std::vector<Tensor> outputs(noutputs);
  Status result;

  if (handle == nullptr) {
    RunOptions run_options_proto;
    if (run_options != nullptr && !run_options_proto.ParseFromArray(
                                      run_options->data, run_options->length)) {
      status->status = InvalidArgument("Unparseable RunOptions proto");
      return;
    }
    if (run_metadata != nullptr && run_metadata->data != nullptr) {
      status->status =
          InvalidArgument("Passing non-empty run_metadata is invalid.");
      return;
    }

    RunMetadata run_metadata_proto;
    result = session->Run(run_options_proto, input_pairs, output_tensor_names,
                          target_oper_names, &outputs, &run_metadata_proto);

    // Serialize back to upstream client, who now owns the new buffer
    if (run_metadata != nullptr) {
      status->status = MessageToBuffer(run_metadata_proto, run_metadata);
      if (!status->status.ok()) return;
    }
  } else {
    // NOTE(zongheng): PRun does not support RunOptions yet.
    result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
  }
  if (!result.ok()) {
    status->status = result;
    return;
  }

  // Store results in c_outputs[]
  for (int i = 0; i < noutputs; ++i) {
    const Tensor& src = outputs[i];
    if (!src.IsInitialized() || src.NumElements() == 0) {
      c_outputs[i] =
          EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
      continue;
    }
    c_outputs[i] = TF_TensorFromTensor(src, status);
    if (!status->status.ok()) return;
  }
}

extern "C" {

void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
            // Input tensors
            const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
            // Output tensors
            const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
            // Target nodes
            const char** c_target_oper_names, int ntargets,
            TF_Buffer* run_metadata, TF_Status* status) {
  TF_Run_Setup(noutputs, c_outputs, status);
  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
  if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
  for (int i = 0; i < ninputs; ++i) {
    input_pairs[i].first = c_input_names[i];
  }
  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = c_output_names[i];
  }
  std::vector<string> target_oper_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_oper_names[i] = c_target_oper_names[i];
  }
  TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
                c_outputs, target_oper_names, run_metadata, status);
}
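
// Illustrative sketch: feeding one tensor and fetching one output through
// TF_Run. The node names "x:0" and "y:0" and the variables `session`,
// `input_tensor` and `status` are placeholders assumed to exist in the
// caller's code.
//
//   const char* input_names[] = {"x:0"};
//   TF_Tensor* input_values[] = {input_tensor};
//   const char* output_names[] = {"y:0"};
//   TF_Tensor* output_values[1] = {nullptr};  // filled in on success
//   TF_Run(session, nullptr, input_names, input_values, 1, output_names,
//          output_values, 1, nullptr, 0, nullptr, status);
//   if (TF_GetCode(status) == TF_OK) TF_DeleteTensor(output_values[0]);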

void TF_PRunSetup(TF_DeprecatedSession* s,
                  // Input names
                  const char** c_input_names, int ninputs,
                  // Output names
                  const char** c_output_names, int noutputs,
                  // Target nodes
                  const char** c_target_oper_names, int ntargets,
                  const char** handle, TF_Status* status) {
  *handle = nullptr;

  std::vector<string> input_names(ninputs);
  std::vector<string> output_names(noutputs);
  std::vector<string> target_oper_names(ntargets);
  for (int i = 0; i < ninputs; ++i) {
    input_names[i] = c_input_names[i];
  }
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = c_output_names[i];
  }
  for (int i = 0; i < ntargets; ++i) {
    target_oper_names[i] = c_target_oper_names[i];
  }
  string new_handle;
  status->status = s->session->PRunSetup(input_names, output_names,
                                         target_oper_names, &new_handle);
  if (status->status.ok()) {
    char* buf = new char[new_handle.size() + 1];
    memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
    *handle = buf;
  }
}

void TF_PRun(TF_DeprecatedSession* s, const char* handle,
             // Input tensors
             const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
             // Output tensors
             const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
             // Target nodes
             const char** c_target_oper_names, int ntargets,
             TF_Status* status) {
  TF_Run_Setup(noutputs, c_outputs, status);
  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
  if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
  for (int i = 0; i < ninputs; ++i) {
    input_pairs[i].first = c_input_names[i];
  }

  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = c_output_names[i];
  }
  std::vector<string> target_oper_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_oper_names[i] = c_target_oper_names[i];
  }
  TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
                c_outputs, target_oper_names, nullptr, status);
}

TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
  TF_Library* lib_handle = new TF_Library;
  status->status = tensorflow::LoadLibrary(
      library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
      &lib_handle->op_list.length);
  if (!status->status.ok()) {
    delete lib_handle;
    return nullptr;
  }
  return lib_handle;
}

TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }

void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
  tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
  delete lib_handle;
}

TF_Buffer* TF_GetAllOpList() {
  std::vector<tensorflow::OpDef> op_defs;
  tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
  tensorflow::OpList op_list;
  for (const auto& op : op_defs) {
    *(op_list.add_op()) = op;
  }
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(op_list, ret));
  return ret;
}

// --------------------------------------------------------------------------
// ListDevices & SessionListDevices API

void TF_DeleteDeviceList(TF_DeviceList* s) { delete s; }

TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
  TF_DeviceList* response = new TF_DeviceList;
  status->status = session->session->ListDevices(&response->response);
  return response;
}

TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
                                               TF_Status* status) {
  TF_DeviceList* response = new TF_DeviceList;
  status->status = session->session->ListDevices(&response->response);
  return response;
}

int TF_DeviceListCount(const TF_DeviceList* list) {
  return list->response.size();
}

#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
  return_type method_name(const TF_DeviceList* list, const int index,     \
                          TF_Status* status) {                            \
    if (list == nullptr) {                                                \
      status->status = InvalidArgument("list is null!");                  \
      return err_val;                                                     \
    }                                                                     \
    if (index < 0 || index >= list->response.size()) {                    \
      status->status = InvalidArgument("index out of bounds");            \
      return err_val;                                                     \
    }                                                                     \
    status->status = Status::OK();                                        \
    return list->response[index].accessor;                                \
  }

TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
                     nullptr);
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);

#undef TF_DEVICELIST_METHOD

}  // end extern "C"

// --------------------------------------------------------------------------
// New Graph and Session API

// Helper functions -----------------------------------------------------------

namespace {

TF_Operation* ToOperation(Node* node) {
  return static_cast<TF_Operation*>(static_cast<void*>(node));
}

string OutputName(const TF_Output& output) {
  return StrCat(output.oper->node.name(), ":", output.index);
}

const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
                                          const char* attr_name,
                                          TF_Status* status) {
  const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
  if (attr == nullptr) {
    status->status = InvalidArgument("Operation '", oper->node.name(),
                                     "' has no attr named '", attr_name, "'.");
  }
  return attr;
}

TensorId ToTensorId(const TF_Output& output) {
  return TensorId(output.oper->node.name(), output.index);
}

#ifndef __ANDROID__
std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
                                                     int n) {
  std::vector<tensorflow::Output> outputs(n);
  for (int i = 0; i < n; ++i) {
    outputs[i] =
        tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
  }
  return outputs;
}

void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
                          TF_Output* tf_outputs) {
  for (int i = 0; i < outputs.size(); i++) {
    tf_outputs[i].oper = ToOperation(outputs[i].node());
    tf_outputs[i].index = outputs[i].index();
  }
}
#endif  // __ANDROID__

}  // namespace

// Shape functions -----------------------------------------------------------

void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
                            const int64_t* dims, const int num_dims,
                            TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return;
  }
  tensorflow::shape_inference::ShapeHandle new_shape =
      tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
  status->status = graph->refiner.SetShape(node, output.index, new_shape);
}

int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
                             TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return -1;
  }

  tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);

  // Unknown rank means the number of dimensions is -1.
  if (!ic->RankKnown(shape)) {
    return -1;
  }

  return ic->Rank(shape);
}

void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
                            int num_dims, TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return;
  }

  tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);

  int rank = -1;
  if (ic->RankKnown(shape)) {
    rank = ic->Rank(shape);
  }

  if (num_dims != rank) {
    status->status = InvalidArgument("Expected rank is ", num_dims,
                                     " but actual rank is ", rank);
    return;
  }

  if (num_dims == 0) {
    // Output shape is a scalar.
    return;
  }

  // Rank is greater than 0, so fill in the values, if known, and
  // -1 for unknown values.
  for (int i = 0; i < num_dims; ++i) {
    auto dim = ic->Dim(shape, i);
    tensorflow::int64 value = -1;
    if (ic->ValueKnown(dim)) {
      value = ic->Value(dim);
    }
    dims[i] = value;
  }
}
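
// Illustrative sketch: querying an output's shape via the two functions
// above. `graph`, `out` (a TF_Output) and `status` are assumed to already
// exist in the caller's code.
//
//   int num_dims = TF_GraphGetTensorNumDims(graph, out, status);
//   if (TF_GetCode(status) == TF_OK && num_dims >= 0) {
//     std::vector<int64_t> dims(num_dims);
//     TF_GraphGetTensorShape(graph, out, dims.data(), num_dims, status);
//     // dims[i] is the size of dimension i, or -1 if it is unknown.
//   }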
1090  
1091  // TF_OperationDescription functions ------------------------------------------
1092  
1093  extern "C" {
1094  
TF_NewOperationLocked(TF_Graph * graph,const char * op_type,const char * oper_name)1095  static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
1096                                                        const char* op_type,
1097                                                        const char* oper_name)
1098      EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
1099    return new TF_OperationDescription(graph, op_type, oper_name);
1100  }
1101  
TF_NewOperation(TF_Graph * graph,const char * op_type,const char * oper_name)1102  TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
1103                                           const char* oper_name) {
1104    mutex_lock l(graph->mu);
1105    return TF_NewOperationLocked(graph, op_type, oper_name);
1106  }
1107  
TF_SetDevice(TF_OperationDescription * desc,const char * device)1108  void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
1109    desc->node_builder.Device(device);
1110  }
1111  
TF_AddInput(TF_OperationDescription * desc,TF_Output input)1112  void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
1113    desc->node_builder.Input(&input.oper->node, input.index);
1114  }
1115  
TF_AddInputList(TF_OperationDescription * desc,const TF_Output * inputs,int num_inputs)1116  void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
1117                       int num_inputs) {
1118    std::vector<NodeBuilder::NodeOut> input_list;
1119    input_list.reserve(num_inputs);
1120    for (int i = 0; i < num_inputs; ++i) {
1121      input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
1122    }
1123    desc->node_builder.Input(input_list);
1124  }
1125  
TF_AddControlInput(TF_OperationDescription * desc,TF_Operation * input)1126  void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
1127    desc->node_builder.ControlInput(&input->node);
1128  }
1129  
TF_ColocateWith(TF_OperationDescription * desc,TF_Operation * op)1130  void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
1131    desc->colocation_constraints.emplace(
1132        StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
1133  }
1134  
TF_SetAttrString(TF_OperationDescription * desc,const char * attr_name,const void * value,size_t length)1135  void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
1136                        const void* value, size_t length) {
1137    tensorflow::StringPiece s(static_cast<const char*>(value), length);
1138    desc->node_builder.Attr(attr_name, s);
1139  }
1140  
TF_SetAttrStringList(TF_OperationDescription * desc,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)1141  void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
1142                            const void* const* values, const size_t* lengths,
1143                            int num_values) {
1144    if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1145      desc->colocation_constraints.clear();
1146      for (int i = 0; i < num_values; ++i) {
1147        desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
1148                                             lengths[i]);
1149      }
1150    } else {
1151      std::vector<tensorflow::StringPiece> v;
1152      v.reserve(num_values);
1153      for (int i = 0; i < num_values; ++i) {
1154        v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
1155      }
1156      desc->node_builder.Attr(attr_name, v);
1157    }
1158  }
1159  
TF_SetAttrInt(TF_OperationDescription * desc,const char * attr_name,int64_t value)1160  void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
1161                     int64_t value) {
1162    static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1163                  "64-bit int types should match in size");
1164    desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
1165  }
1166  
TF_SetAttrIntList(TF_OperationDescription * desc,const char * attr_name,const int64_t * values,int num_values)1167  void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
1168                         const int64_t* values, int num_values) {
1169    static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1170                  "64-bit int types should match in size");
1171    desc->node_builder.Attr(
1172        attr_name,
1173        ArraySlice<const tensorflow::int64>(
1174            reinterpret_cast<const tensorflow::int64*>(values), num_values));
1175  }
1176  
TF_SetAttrFloat(TF_OperationDescription * desc,const char * attr_name,float value)1177  void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
1178                       float value) {
1179    desc->node_builder.Attr(attr_name, value);
1180  }
1181  
TF_SetAttrFloatList(TF_OperationDescription * desc,const char * attr_name,const float * values,int num_values)1182  void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
1183                           const float* values, int num_values) {
1184    desc->node_builder.Attr(attr_name,
1185                            ArraySlice<const float>(values, num_values));
1186  }
1187  
TF_SetAttrBool(TF_OperationDescription * desc,const char * attr_name,unsigned char value)1188  void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
1189                      unsigned char value) {
1190    desc->node_builder.Attr(attr_name, static_cast<bool>(value));
1191  }
1192  
TF_SetAttrBoolList(TF_OperationDescription * desc,const char * attr_name,const unsigned char * values,int num_values)1193  void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
1194                          const unsigned char* values, int num_values) {
1195    std::unique_ptr<bool[]> b(new bool[num_values]);
1196    for (int i = 0; i < num_values; ++i) {
1197      b[i] = values[i];
1198    }
1199    desc->node_builder.Attr(attr_name,
1200                            ArraySlice<const bool>(b.get(), num_values));
1201  }
1202  
TF_SetAttrType(TF_OperationDescription * desc,const char * attr_name,TF_DataType value)1203  void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
1204                      TF_DataType value) {
1205    desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
1206  }
1207  
TF_SetAttrTypeList(TF_OperationDescription * desc,const char * attr_name,const TF_DataType * values,int num_values)1208  void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
1209                          const TF_DataType* values, int num_values) {
1210    desc->node_builder.Attr(
1211        attr_name, ArraySlice<const DataType>(
1212                       reinterpret_cast<const DataType*>(values), num_values));
1213  }
1214  
TF_SetAttrFuncName(TF_OperationDescription * desc,const char * attr_name,const char * value,size_t length)1215  void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
1216                          const char* value, size_t length) {
1217    tensorflow::NameAttrList func_name;
1218    func_name.set_name(std::string(value, value + length));
1219    desc->node_builder.Attr(attr_name, func_name);
1220  }
1221  
TF_SetAttrShape(TF_OperationDescription * desc,const char * attr_name,const int64_t * dims,int num_dims)1222  void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
1223                       const int64_t* dims, int num_dims) {
1224    PartialTensorShape shape;
1225    if (num_dims >= 0) {
1226      static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1227                    "64-bit int types should match in size");
1228      shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
1229          reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
1230    }
1231    desc->node_builder.Attr(attr_name, shape);
1232  }
1233  
1234  void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
1235                           const int64_t* const* dims, const int* num_dims,
1236                           int num_shapes) {
1237    std::vector<PartialTensorShape> shapes;
1238    shapes.reserve(num_shapes);
1239    for (int i = 0; i < num_shapes; ++i) {
1240      if (num_dims[i] < 0) {
1241        shapes.emplace_back();
1242      } else {
1243        static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1244                      "64-bit int types should match in size");
1245        shapes.emplace_back(ArraySlice<tensorflow::int64>(
1246            reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
1247      }
1248    }
1249    desc->node_builder.Attr(attr_name, shapes);
1250  }
1251  
1252  void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
1253                                  const char* attr_name, const void* proto,
1254                                  size_t proto_len, TF_Status* status) {
1255    // shape.ParseFromArray takes an int as length, this function takes size_t,
1256    // make sure there is no information loss.
1257    if (proto_len > std::numeric_limits<int>::max()) {
1258      status->status = InvalidArgument(
1259          "proto_len (", proto_len,
1260          " bytes) is too large to be parsed by the protocol buffer library");
1261      return;
1262    }
1263    TensorShapeProto shape;
1264    if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
1265      desc->node_builder.Attr(attr_name, shape);
1266      status->status = Status::OK();
1267    } else {
1268      status->status = InvalidArgument("Unparseable TensorShapeProto");
1269    }
1270  }
1271  
1272  void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
1273                                      const char* attr_name,
1274                                      const void* const* protos,
1275                                      const size_t* proto_lens, int num_shapes,
1276                                      TF_Status* status) {
1277    std::vector<TensorShapeProto> shapes;
1278    shapes.resize(num_shapes);
1279    for (int i = 0; i < num_shapes; ++i) {
1280      if (proto_lens[i] > std::numeric_limits<int>::max()) {
1281        status->status = InvalidArgument(
1282            "length of element ", i, " in the list (", proto_lens[i],
1283            " bytes) is too large to be parsed by the protocol buffer library");
1284        return;
1285      }
1286      if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) {
1287        status->status =
1288            InvalidArgument("Unparseable TensorShapeProto at index ", i);
1289        return;
1290      }
1291    }
1292    desc->node_builder.Attr(attr_name, shapes);
1293    status->status = Status::OK();
1294  }
1295  
1296  void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
1297                        TF_Tensor* value, TF_Status* status) {
1298    Tensor t;
1299    status->status = TF_TensorToTensor(value, &t);
1300    if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1301  }
1302  
1303  void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
1304                            TF_Tensor* const* values, int num_values,
1305                            TF_Status* status) {
1306    status->status = Status::OK();
1307    std::vector<Tensor> t;
1308    t.reserve(num_values);
1309  
1310    for (int i = 0; i < num_values && status->status.ok(); ++i) {
1311      Tensor v;
1312      status->status = TF_TensorToTensor(values[i], &v);
1313      t.emplace_back(v);
1314    }
1315  
1316    if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1317  }
1318  
1319  void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
1320                            const void* proto, size_t proto_len,
1321                            TF_Status* status) {
1322    tensorflow::AttrValue attr_value;
1323    if (!attr_value.ParseFromArray(proto, proto_len)) {
1324      status->status = InvalidArgument("Unparseable AttrValue proto");
1325      return;
1326    }
1327  
1328    if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1329      if (attr_value.value_case() != tensorflow::AttrValue::kList &&
1330          attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
1331        status->status =
1332            InvalidArgument("Expected \"list\" field for \"",
1333                            tensorflow::kColocationAttrName, "\" attribute");
1334        return;
1335      }
1336      desc->colocation_constraints.clear();
1337      for (const string& location : attr_value.list().s()) {
1338        desc->colocation_constraints.insert(location);
1339      }
1340    } else {
1341      desc->node_builder.Attr(attr_name, attr_value);
1342    }
1343  
1344    status->status = Status::OK();
1345  }
1346  
1347  static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
1348                                                TF_Status* status)
1349      EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
1350    Node* ret = nullptr;
1351  
1352    if (desc->graph->name_map.count(desc->node_builder.node_name())) {
1353      status->status = InvalidArgument("Duplicate node name in graph: '",
1354                                       desc->node_builder.node_name(), "'");
1355    } else {
1356      if (!desc->colocation_constraints.empty()) {
1357        desc->node_builder.Attr(
1358            tensorflow::kColocationAttrName,
1359            std::vector<string>(desc->colocation_constraints.begin(),
1360                                desc->colocation_constraints.end()));
1361      }
1362      status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);
1363  
1364      if (status->status.ok()) {
1365        // Run shape inference function for newly added node.
1366        status->status = desc->graph->refiner.AddNode(ret);
1367      }
1368      if (status->status.ok()) {
1369        // Add the node to the name-to-node mapping.
1370        desc->graph->name_map[ret->name()] = ret;
1371      } else if (ret != nullptr) {
1372        desc->graph->graph.RemoveNode(ret);
1373        ret = nullptr;
1374      }
1375    }
1376  
1377    delete desc;
1378  
1379    return ToOperation(ret);
1380  }
1381  
1382  TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
1383                                   TF_Status* status) {
1384    mutex_lock l(desc->graph->mu);
1385    return TF_FinishOperationLocked(desc, status);
1386  }
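// Editorial usage sketch (illustrative, not part of the original file): how a
// client of the C API might combine the attribute setters above with
// TF_FinishOperation(). The graph `g`, the producing operation `feed`, and the
// status `s` are assumed to exist; errors are reported through `s` and the
// returned pointer is null on failure.
//
//   TF_OperationDescription* d = TF_NewOperation(g, "Cast", "cast0");
//   TF_AddInput(d, {feed, 0});
//   TF_SetAttrType(d, "SrcT", TF_OperationOutputType({feed, 0}));
//   TF_SetAttrType(d, "DstT", TF_FLOAT);
//   TF_Operation* cast = TF_FinishOperation(d, s);  // also frees `d`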
1387  
1388  // TF_Operation functions
1389  // ----------------------------------------------------------
1390  
1391  const char* TF_OperationName(TF_Operation* oper) {
1392    return oper->node.name().c_str();
1393  }
1394  
1395  const char* TF_OperationOpType(TF_Operation* oper) {
1396    return oper->node.type_string().c_str();
1397  }
1398  
1399  const char* TF_OperationDevice(TF_Operation* oper) {
1400    return oper->node.requested_device().c_str();
1401  }
1402  
1403  int TF_OperationNumOutputs(TF_Operation* oper) {
1404    return oper->node.num_outputs();
1405  }
1406  
1407  TF_DataType TF_OperationOutputType(TF_Output oper_out) {
1408    return static_cast<TF_DataType>(
1409        oper_out.oper->node.output_type(oper_out.index));
1410  }
1411  
1412  int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
1413                                   TF_Status* status) {
1414    NameRangeMap name_ranges;
1415    status->status =
1416        NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
1417    if (!status->status.ok()) return -1;
1418    auto iter = name_ranges.find(arg_name);
1419    if (iter == name_ranges.end()) {
1420      status->status = InvalidArgument("Output arg '", arg_name, "' not found");
1421      return -1;
1422    }
1423    return iter->second.second - iter->second.first;
1424  }
1425  
1426  int TF_OperationNumInputs(TF_Operation* oper) {
1427    return oper->node.num_inputs();
1428  }
1429  
1430  TF_DataType TF_OperationInputType(TF_Input oper_in) {
1431    return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
1432  }
1433  
1434  int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
1435                                  TF_Status* status) {
1436    NameRangeMap name_ranges;
1437    status->status =
1438        NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
1439    if (!status->status.ok()) return -1;
1440    auto iter = name_ranges.find(arg_name);
1441    if (iter == name_ranges.end()) {
1442      status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1443      return -1;
1444    }
1445    return iter->second.second - iter->second.first;
1446  }
1447  
1448  TF_Output TF_OperationInput(TF_Input oper_in) {
1449    const tensorflow::Edge* edge;
1450    Status s = oper_in.oper->node.input_edge(oper_in.index, &edge);
1451    if (!s.ok()) {
1452      return {nullptr, -1};
1453    }
1454  
1455    return {ToOperation(edge->src()), edge->src_output()};
1456  }
1457  
1458  int TF_OperationOutputNumConsumers(TF_Output oper_out) {
1459    int count = 0;
1460    for (const auto* edge : oper_out.oper->node.out_edges()) {
1461      if (edge->src_output() == oper_out.index) {
1462        ++count;
1463      }
1464    }
1465    return count;
1466  }
1467  
1468  int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
1469                                  int max_consumers) {
1470    int count = 0;
1471    for (const auto* edge : oper_out.oper->node.out_edges()) {
1472      if (edge->src_output() == oper_out.index) {
1473        if (count < max_consumers) {
1474          consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
1475        }
1476        ++count;
1477      }
1478    }
1479    return count;
1480  }
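// Editorial usage sketch (illustrative, not part of the original file): the
// two-call pattern intended for the consumer queries above. `out` is assumed
// to be a valid TF_Output.
//
//   int n = TF_OperationOutputNumConsumers(out);
//   std::vector<TF_Input> consumers(n);
//   int filled = TF_OperationOutputConsumers(out, consumers.data(), n);
//   for (int i = 0; i < filled && i < n; ++i) {
//     printf("%s:%d\n", TF_OperationName(consumers[i].oper),
//            consumers[i].index);
//   }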
1481  
1482  int TF_OperationNumControlInputs(TF_Operation* oper) {
1483    int count = 0;
1484    for (const auto* edge : oper->node.in_edges()) {
1485      if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1486        ++count;
1487      }
1488    }
1489    return count;
1490  }
1491  
1492  int TF_OperationGetControlInputs(TF_Operation* oper,
1493                                   TF_Operation** control_inputs,
1494                                   int max_control_inputs) {
1495    int count = 0;
1496    for (const auto* edge : oper->node.in_edges()) {
1497      if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1498        if (count < max_control_inputs) {
1499          control_inputs[count] = ToOperation(edge->src());
1500        }
1501        ++count;
1502      }
1503    }
1504    return count;
1505  }
1506  
1507  int TF_OperationNumControlOutputs(TF_Operation* oper) {
1508    int count = 0;
1509    for (const auto* edge : oper->node.out_edges()) {
1510      if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1511        ++count;
1512      }
1513    }
1514    return count;
1515  }
1516  
1517  int TF_OperationGetControlOutputs(TF_Operation* oper,
1518                                    TF_Operation** control_outputs,
1519                                    int max_control_outputs) {
1520    int count = 0;
1521    for (const auto* edge : oper->node.out_edges()) {
1522      if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1523        if (count < max_control_outputs) {
1524          control_outputs[count] = ToOperation(edge->dst());
1525        }
1526        ++count;
1527      }
1528    }
1529    return count;
1530  }
1531  
1532  TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1533                                              const char* attr_name,
1534                                              TF_Status* status) {
1535    TF_AttrMetadata metadata;
1536    const auto* attr = GetAttrValue(oper, attr_name, status);
1537    if (!status->status.ok()) return metadata;
1538    switch (attr->value_case()) {
1539  #define SINGLE_CASE(kK, attr_type, size_expr) \
1540    case tensorflow::AttrValue::kK:             \
1541      metadata.is_list = 0;                     \
1542      metadata.list_size = -1;                  \
1543      metadata.type = attr_type;                \
1544      metadata.total_size = size_expr;          \
1545      break;
1546  
1547      SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
1548      SINGLE_CASE(kI, TF_ATTR_INT, -1);
1549      SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
1550      SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
1551      SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
1552      SINGLE_CASE(kShape, TF_ATTR_SHAPE,
1553                  attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
1554      SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
1555  #undef SINGLE_CASE
1556  
1557      case tensorflow::AttrValue::kList:
1558        metadata.is_list = 1;
1559        metadata.list_size = 0;
1560        metadata.total_size = -1;
1561  #define LIST_CASE(field, attr_type, ...)              \
1562    if (attr->list().field##_size() > 0) {              \
1563      metadata.type = attr_type;                        \
1564      metadata.list_size = attr->list().field##_size(); \
1565      __VA_ARGS__;                                      \
1566      break;                                            \
1567    }
1568  
1569        LIST_CASE(s, TF_ATTR_STRING, metadata.total_size = 0;
1570                  for (int i = 0; i < attr->list().s_size();
1571                       ++i) { metadata.total_size += attr->list().s(i).size(); });
1572        LIST_CASE(i, TF_ATTR_INT);
1573        LIST_CASE(f, TF_ATTR_FLOAT);
1574        LIST_CASE(b, TF_ATTR_BOOL);
1575        LIST_CASE(type, TF_ATTR_TYPE);
1576        LIST_CASE(shape, TF_ATTR_SHAPE, metadata.total_size = 0;
1577                  for (int i = 0; i < attr->list().shape_size(); ++i) {
1578                    const auto& s = attr->list().shape(i);
1579                    metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
1580                  });
1581        LIST_CASE(tensor, TF_ATTR_TENSOR);
1582        LIST_CASE(func, TF_ATTR_FUNC);
1583  #undef LIST_CASE
1584        // All lists empty, determine the type from the OpDef.
1585        if (metadata.list_size == 0) {
1586          for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
1587            const auto& a = oper->node.op_def().attr(i);
1588            if (a.name().compare(attr_name) != 0) continue;
1589            const string& typestr = a.type();
1590            if (typestr == "list(string)") {
1591              metadata.type = TF_ATTR_STRING;
1592            } else if (typestr == "list(int)") {
1593              metadata.type = TF_ATTR_INT;
1594            } else if (typestr == "list(float)") {
1595              metadata.type = TF_ATTR_FLOAT;
1596            } else if (typestr == "list(bool)") {
1597              metadata.type = TF_ATTR_BOOL;
1598            } else if (typestr == "list(type)") {
1599              metadata.type = TF_ATTR_TYPE;
1600            } else if (typestr == "list(shape)") {
1601              metadata.type = TF_ATTR_SHAPE;
1602            } else if (typestr == "list(tensor)") {
1603              metadata.type = TF_ATTR_TENSOR;
1604            } else if (typestr == "list(func)") {
1605              metadata.type = TF_ATTR_FUNC;
1606            } else {
1607              status->status = InvalidArgument(
1608                  "Attribute '", attr_name,
1609                  "' has an empty value of an unrecognized type '", typestr, "'");
1610              return metadata;
1611            }
1612          }
1613        }
1614        break;
1615  
1616      case tensorflow::AttrValue::kPlaceholder:
1617        metadata.is_list = 0;
1618        metadata.list_size = -1;
1619        metadata.type = TF_ATTR_PLACEHOLDER;
1620        metadata.total_size = -1;
1621        break;
1622  
1623      case tensorflow::AttrValue::kFunc:
1624        metadata.is_list = 0;
1625        metadata.list_size = -1;
1626        metadata.type = TF_ATTR_FUNC;
1627        metadata.total_size = -1;
1628        break;
1629  
1630      case tensorflow::AttrValue::VALUE_NOT_SET:
1631        status->status =
1632            InvalidArgument("Attribute '", attr_name, "' has no value set");
1633        break;
1634    }
1635    return metadata;
1636  }
1637  
1638  void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
1639                                 void* value, size_t max_length,
1640                                 TF_Status* status) {
1641    const auto* attr = GetAttrValue(oper, attr_name, status);
1642    if (!status->status.ok()) return;
1643    if (attr->value_case() != tensorflow::AttrValue::kS) {
1644      status->status =
1645          InvalidArgument("Attribute '", attr_name, "' is not a string");
1646      return;
1647    }
1648    if (max_length <= 0) {
1649      return;
1650    }
1651    const auto& s = attr->s();
1652    std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length));
1653  }
1654  
1655  void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
1656                                     void** values, size_t* lengths,
1657                                     int max_values, void* storage,
1658                                     size_t storage_size, TF_Status* status) {
1659    const auto* attr = GetAttrValue(oper, attr_name, status);
1660    if (!status->status.ok()) return;
1661    if (attr->value_case() != tensorflow::AttrValue::kList) {
1662      status->status =
1663          InvalidArgument("Value for '", attr_name, "' is not a list");
1664      return;
1665    }
1666    const auto len = std::min(max_values, attr->list().s_size());
1667    char* p = static_cast<char*>(storage);
1668    for (int i = 0; i < len; ++i) {
1669      const string& s = attr->list().s(i);
1670      values[i] = p;
1671      lengths[i] = s.size();
1672      if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
1673        status->status = InvalidArgument(
1674            "Not enough storage to hold the requested list of strings");
1675        return;
1676      }
1677      memcpy(values[i], s.data(), s.size());
1678      p += s.size();
1679    }
1680  }
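// Editorial usage sketch (illustrative, not part of the original file): sizing
// the buffers for TF_OperationGetAttrStringList() from the attribute metadata
// returned by TF_OperationGetAttrMetadata(). `oper` and `s` are assumed valid;
// "value" is a hypothetical list(string) attribute name.
//
//   TF_AttrMetadata m = TF_OperationGetAttrMetadata(oper, "value", s);
//   std::vector<void*> vals(m.list_size);
//   std::vector<size_t> lens(m.list_size);
//   std::vector<char> storage(m.total_size);
//   TF_OperationGetAttrStringList(oper, "value", vals.data(), lens.data(),
//                                 static_cast<int>(m.list_size),
//                                 storage.data(), storage.size(), s);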
1681  
1682  #define DEFINE_GETATTR(func, c_type, cpp_type, list_field)                   \
1683    void func(TF_Operation* oper, const char* attr_name, c_type* value,        \
1684              TF_Status* status) {                                             \
1685      cpp_type v;                                                              \
1686      status->status =                                                         \
1687          tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v);          \
1688      *value = static_cast<c_type>(v);                                         \
1689    }                                                                          \
1690    void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
1691                    int max_values, TF_Status* status) {                       \
1692      const auto* attr = GetAttrValue(oper, attr_name, status);                \
1693      if (!status->status.ok()) return;                                        \
1694      if (attr->value_case() != tensorflow::AttrValue::kList) {                \
1695        status->status =                                                       \
1696            InvalidArgument("Value for '", attr_name, "' is not a list.");     \
1697        return;                                                                \
1698      }                                                                        \
1699      const auto len = std::min(max_values, attr->list().list_field##_size()); \
1700      for (int i = 0; i < len; ++i) {                                          \
1701        values[i] = static_cast<c_type>(attr->list().list_field(i));           \
1702      }                                                                        \
1703    }
1704  DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
1705  DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
1706  DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
1707  DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
1708  #undef DEFINE_GETATTR
1709  
1710  void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
1711                                int64_t* value, int num_dims, TF_Status* status) {
1712    PartialTensorShape shape;
1713    status->status =
1714        tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
1715    if (!status->status.ok()) return;
1716    auto len = std::min(shape.dims(), num_dims);
1717    for (int i = 0; i < len; ++i) {
1718      value[i] = shape.dim_size(i);
1719    }
1720  }
1721  
1722  void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
1723                                    int64_t** values, int* num_dims,
1724                                    int max_values, int64_t* storage,
1725                                    int storage_size, TF_Status* status) {
1726    std::vector<PartialTensorShape> shapes;
1727    status->status =
1728        tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
1729    if (!status->status.ok()) return;
1730    auto len = std::min(static_cast<int>(shapes.size()), max_values);
1731    int64_t* p = storage;
1732    int storage_left = storage_size;
1733    for (int i = 0; i < len; ++i) {
1734      // shapes[i].dims() == -1 for shapes with an unknown rank.
1735      int64_t n = shapes[i].dims();
1736      num_dims[i] = n;
1737      values[i] = p;
1738      if (n < 0) {
1739        continue;
1740      }
1741      if (storage_left < n) {
1742        status->status = InvalidArgument(
1743            "Not enough storage to hold the requested list of shapes");
1744        return;
1745      }
1746      storage_left -= n;
1747      for (int j = 0; j < n; ++j, ++p) {
1748        *p = shapes[i].dim_size(j);
1749      }
1750    }
1751  }
1752  
1753  void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
1754                                           const char* attr_name,
1755                                           TF_Buffer* value, TF_Status* status) {
1756    const auto* attr = GetAttrValue(oper, attr_name, status);
1757    if (!status->status.ok()) return;
1758    if (attr->value_case() != tensorflow::AttrValue::kShape) {
1759      status->status =
1760          InvalidArgument("Value for '", attr_name, "' is not a shape.");
1761      return;
1762    }
1763    status->status = MessageToBuffer(attr->shape(), value);
1764  }
1765  
1766  void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
1767                                               const char* attr_name,
1768                                               TF_Buffer** values, int max_values,
1769                                               TF_Status* status) {
1770    const auto* attr = GetAttrValue(oper, attr_name, status);
1771    if (!status->status.ok()) return;
1772    if (attr->value_case() != tensorflow::AttrValue::kList) {
1773      status->status =
1774          InvalidArgument("Value for '", attr_name, "' is not a list");
1775      return;
1776    }
1777    const auto len = std::min(max_values, attr->list().shape_size());
1778    for (int i = 0; i < len; ++i) {
1779      values[i] = TF_NewBuffer();
1780      status->status = MessageToBuffer(attr->list().shape(i), values[i]);
1781      if (!status->status.ok()) {
1782      // Delete everything allocated so far; the operation has failed.
1783        for (int j = 0; j <= i; ++j) {
1784          TF_DeleteBuffer(values[j]);
1785        }
1786        return;
1787      }
1788    }
1789  }
1790  
1791  void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
1792                                 TF_Tensor** value, TF_Status* status) {
1793    *value = nullptr;
1794    Tensor t;
1795    status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
1796    if (!status->status.ok()) return;
1797    *value = TF_TensorFromTensor(t, status);
1798  }
1799  
1800  void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
1801                                     TF_Tensor** values, int max_values,
1802                                     TF_Status* status) {
1803    std::vector<Tensor> ts;
1804    status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
1805    if (!status->status.ok()) return;
1806    const auto len = std::min(max_values, static_cast<int>(ts.size()));
1807    for (int i = 0; i < len; ++i) {
1808      values[i] = TF_TensorFromTensor(ts[i], status);
1809    }
1810  }
1811  
1812  void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
1813                                     TF_Buffer* output_attr_value,
1814                                     TF_Status* status) {
1815    const auto* attr = GetAttrValue(oper, attr_name, status);
1816    if (!status->status.ok()) return;
1817    status->status = MessageToBuffer(*attr, output_attr_value);
1818  }
1819  
1820  void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
1821                             TF_Status* status) {
1822    status->status = MessageToBuffer(oper->node.def(), output_node_def);
1823  }
1824  
1825  // TF_Graph functions ---------------------------------------------------------
1826  
1827  TF_Graph::TF_Graph()
1828      : graph(tensorflow::OpRegistry::Global()),
1829        refiner(graph.versions().producer(), graph.op_registry()),
1830        delete_requested(false),
1831        parent(nullptr),
1832        parent_inputs(nullptr) {}
1833  
1834  TF_Graph* TF_NewGraph() { return new TF_Graph; }
1835  
1836  void TF_DeleteGraph(TF_Graph* g) {
1837    g->mu.lock();
1838    g->delete_requested = true;
1839    const bool del = g->sessions.empty();
1840    g->mu.unlock();
1841    if (del) delete g;
1842  }
1843  
1844  TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
1845    mutex_lock l(graph->mu);
1846    auto iter = graph->name_map.find(oper_name);
1847    if (iter == graph->name_map.end()) {
1848      return nullptr;
1849    } else {
1850      return ToOperation(iter->second);
1851    }
1852  }
1853  
1854  TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
1855    if (*pos == 0) {
1856      // Advance past the first sentinel nodes in every graph (the source & sink).
1857      *pos += 2;
1858    } else {
1859      // Advance to the next node.
1860      *pos += 1;
1861    }
1862  
1863    mutex_lock l(graph->mu);
1864    while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) {
1865      Node* node = graph->graph.FindNodeId(*pos);
1866      // FindNodeId() returns nullptr for nodes that have been deleted.
1867      // We aren't currently allowing nodes to be deleted, but it is safer
1868      // to still check.
1869      if (node != nullptr) return ToOperation(node);
1870      *pos += 1;
1871    }
1872  
1873    // No more nodes.
1874    return nullptr;
1875  }
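// Editorial usage sketch (illustrative, not part of the original file):
// enumerating every operation in a graph with the iterator above. `graph` is
// assumed to be a valid TF_Graph*.
//
//   size_t pos = 0;
//   TF_Operation* op;
//   while ((op = TF_GraphNextOperation(graph, &pos)) != nullptr) {
//     printf("%s (%s)\n", TF_OperationName(op), TF_OperationOpType(op));
//   }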
1876  
1877  void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
1878                          TF_Status* status) {
1879    GraphDef def;
1880    {
1881      mutex_lock l(graph->mu);
1882      graph->graph.ToGraphDef(&def);
1883    }
1884    status->status = MessageToBuffer(def, output_graph_def);
1885  }
1886  
1887  void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
1888                        TF_Buffer* output_op_def, TF_Status* status) {
1889    const OpDef* op_def;
1890    {
1891      mutex_lock l(graph->mu);
1892      status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
1893      if (!status->status.ok()) return;
1894    }
1895    status->status = MessageToBuffer(*op_def, output_op_def);
1896  }
1897  
1898  void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
1899                        TF_Status* status) {
1900    VersionDef versions;
1901    {
1902      mutex_lock l(graph->mu);
1903      versions = graph->graph.versions();
1904    }
1905    status->status = MessageToBuffer(versions, output_version_def);
1906  }
1907  
1908  TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
1909    return new TF_ImportGraphDefOptions;
1910  }
1911  void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
1912    delete opts;
1913  }
1914  void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
1915                                         const char* prefix) {
1916    opts->opts.prefix = prefix;
1917  }
1918  
1919  void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
1920                                                unsigned char uniquify_names) {
1921    opts->opts.uniquify_names = uniquify_names;
1922  }
1923  
1924  void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
1925                                                 unsigned char uniquify_prefix) {
1926    opts->opts.uniquify_prefix = uniquify_prefix;
1927  }
1928  
1929  void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
1930                                               const char* src_name,
1931                                               int src_index, TF_Output dst) {
1932    opts->tensor_id_data.push_back(src_name);
1933    const string& src_name_str = opts->tensor_id_data.back();
1934    // We don't need to store dst's name in tensor_id_data, since `dst` must
1935    // outlive the ImportGraphDef call.
1936    opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
1937  }
1938  
1939  void TF_ImportGraphDefOptionsRemapControlDependency(
1940      TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
1941    opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
1942        TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
1943  }
1944  
1945  extern void TF_ImportGraphDefOptionsAddControlDependency(
1946      TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
1947    opts->opts.control_dependencies.push_back(oper->node.name());
1948  }
1949  
1950  void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
1951                                               const char* oper_name, int index) {
1952    opts->tensor_id_data.push_back(oper_name);
1953    const string& oper_name_str = opts->tensor_id_data.back();
1954    opts->opts.return_tensors.emplace_back(oper_name_str, index);
1955  }
1956  
1957  int TF_ImportGraphDefOptionsNumReturnOutputs(
1958      const TF_ImportGraphDefOptions* opts) {
1959    return opts->opts.return_tensors.size();
1960  }
1961  
1962  void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
1963                                                  const char* oper_name) {
1964    opts->opts.return_nodes.push_back(oper_name);
1965  }
1966  
1967  int TF_ImportGraphDefOptionsNumReturnOperations(
1968      const TF_ImportGraphDefOptions* opts) {
1969    return opts->opts.return_nodes.size();
1970  }
1971  
1972  void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
1973                                             int* num_outputs,
1974                                             TF_Output** outputs) {
1975    *num_outputs = results->return_tensors.size();
1976    *outputs = results->return_tensors.data();
1977  }
1978  
1979  void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
1980                                                int* num_opers,
1981                                                TF_Operation*** opers) {
1982    *num_opers = results->return_nodes.size();
1983    *opers = results->return_nodes.data();
1984  }
1985  
1986  void TF_ImportGraphDefResultsMissingUnusedInputMappings(
1987      TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
1988      const char*** src_names, int** src_indexes) {
1989    *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
1990    *src_names = results->missing_unused_key_names.data();
1991    *src_indexes = results->missing_unused_key_indexes.data();
1992  }
1993  
1994  void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
1995    delete results;
1996  }
1997  
1998  static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
1999                                        const TF_ImportGraphDefOptions* opts,
2000                                        TF_ImportGraphDefResults* tf_results,
2001                                        TF_Status* status)
2002      EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
2003    const int last_node_id = graph->graph.num_node_ids();
2004    tensorflow::ImportGraphDefResults results;
2005    status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
2006                                                &graph->refiner, &results);
2007    if (!status->status.ok()) return;
2008  
2009    // Add new nodes to name_map
2010    for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
2011      auto* node = graph->graph.FindNodeId(i);
2012      if (node != nullptr) graph->name_map[node->name()] = node;
2013    }
2014  
2015    // Populate return_tensors
2016    DCHECK(tf_results->return_tensors.empty());
2017    tf_results->return_tensors.resize(results.return_tensors.size());
2018    for (int i = 0; i < results.return_tensors.size(); ++i) {
2019      tf_results->return_tensors[i].oper =
2020          ToOperation(results.return_tensors[i].first);
2021      tf_results->return_tensors[i].index = results.return_tensors[i].second;
2022    }
2023  
2024    // Populate return_nodes
2025    DCHECK(tf_results->return_nodes.empty());
2026    tf_results->return_nodes.resize(results.return_nodes.size());
2027    for (int i = 0; i < results.return_nodes.size(); ++i) {
2028      tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
2029    }
2030  
2031    // Populate missing unused map keys
2032    DCHECK(tf_results->missing_unused_key_names.empty());
2033    DCHECK(tf_results->missing_unused_key_indexes.empty());
2034    DCHECK(tf_results->missing_unused_key_names_data.empty());
2035  
2036    size_t size = results.missing_unused_input_map_keys.size();
2037    tf_results->missing_unused_key_names.resize(size);
2038    tf_results->missing_unused_key_indexes.resize(size);
2039  
2040    for (int i = 0; i < size; ++i) {
2041      TensorId id = results.missing_unused_input_map_keys[i];
2042      tf_results->missing_unused_key_names_data.push_back(id.first.ToString());
2043      tf_results->missing_unused_key_names[i] =
2044          tf_results->missing_unused_key_names_data.back().c_str();
2045      tf_results->missing_unused_key_indexes[i] = id.second;
2046    }
2047  }
2048  
2049  TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
2050      TF_Graph* graph, const TF_Buffer* graph_def,
2051      const TF_ImportGraphDefOptions* options, TF_Status* status) {
2052    GraphDef def;
2053    if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
2054      status->status = InvalidArgument("Invalid GraphDef");
2055      return nullptr;
2056    }
2057    auto results = new TF_ImportGraphDefResults();
2058    mutex_lock l(graph->mu);
2059    GraphImportGraphDefLocked(graph, def, options, results, status);
2060    if (!status->status.ok()) {
2061      delete results;
2062      return nullptr;
2063    }
2064    return results;
2065  }
2066  
2067  void TF_GraphImportGraphDefWithReturnOutputs(
2068      TF_Graph* graph, const TF_Buffer* graph_def,
2069      const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
2070      int num_return_outputs, TF_Status* status) {
2071    if (num_return_outputs != options->opts.return_tensors.size()) {
2072      status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
2073                                       options->opts.return_tensors.size(),
2074                                       ", got ", num_return_outputs);
2075      return;
2076    }
2077    if (num_return_outputs > 0 && return_outputs == nullptr) {
2078      status->status = InvalidArgument(
2079          "'return_outputs' must be preallocated to length ", num_return_outputs);
2080      return;
2081    }
2082    GraphDef def;
2083    if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
2084      status->status = InvalidArgument("Invalid GraphDef");
2085      return;
2086    }
2087    TF_ImportGraphDefResults results;
2088    mutex_lock l(graph->mu);
2089    GraphImportGraphDefLocked(graph, def, options, &results, status);
2090    DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
2091    memcpy(return_outputs, results.return_tensors.data(),
2092           num_return_outputs * sizeof(TF_Output));
2093  }
2094  
2095  void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
2096                              const TF_ImportGraphDefOptions* options,
2097                              TF_Status* status) {
2098    TF_ImportGraphDefResults* results =
2099        TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
2100    TF_DeleteImportGraphDefResults(results);
2101  }
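// Editorial usage sketch (illustrative, not part of the original file):
// importing a serialized GraphDef under a name prefix using the options and
// import entry points above. `graph`, `s`, and the TF_Buffer `graph_def_buf`
// (holding a valid GraphDef) are assumed to exist.
//
//   TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
//   TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
//   TF_GraphImportGraphDef(graph, graph_def_buf, opts, s);
//   TF_DeleteImportGraphDefOptions(opts);
//   if (TF_GetCode(s) != TF_OK) { /* handle import error */ }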
2102  
2103  // While loop functions -------------------------------------------------------
2104  
2105  namespace {
2106  
2107  #ifndef __ANDROID__
2108  
2109  // Creates a placeholder representing an input to the cond or body graph.
2110  // TODO(skyewm): remove these from final graph
2111  bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
2112                   TF_Output* input, TF_Status* status) {
2113    TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
2114    TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
2115    // TODO(skyewm): set placeholder shape
2116    TF_Operation* oper = TF_FinishOperation(desc, status);
2117    if (!status->status.ok()) return false;
2118    *input = {oper, 0};
2119    return true;
2120  }
2121  
2122  // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
2123  // `src_inputs[i]` will have that input replaced with `dst_inputs[i]`.  `prefix`
2124  // will be prepended to copied node names. `control_deps` are nodes in
2125  // `dst_graph` that the copied `src_graph` nodes will have control dependencies
2126  // on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
2127  // in `dst_graph` will be returned. `return_nodes` must be non-null.
2128  Status CopyGraph(Graph* src_graph, Graph* dst_graph,
2129                   tensorflow::ShapeRefiner* dst_refiner,
2130                   const TF_Output* src_inputs,
2131                   const std::vector<tensorflow::Output>& dst_inputs,
2132                   const string& prefix,
2133                   const std::vector<tensorflow::Operation>& control_deps,
2134                   const TF_Output* nodes_to_return, int nreturn_nodes,
2135                   std::vector<tensorflow::Output>* return_nodes) {
2136    DCHECK(return_nodes != nullptr);
2137    GraphDef gdef;
2138    src_graph->ToGraphDef(&gdef);
2139  
2140    tensorflow::ImportGraphDefOptions opts;
2141    opts.prefix = prefix;
2142  
2143    for (int i = 0; i < dst_inputs.size(); ++i) {
2144      opts.input_map[ToTensorId(src_inputs[i])] =
2145          TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
2146    }
2147    opts.skip_mapped_nodes = true;
2148  
2149    for (const tensorflow::Operation& op : control_deps) {
2150      opts.control_dependencies.push_back(op.node()->name());
2151    }
2152  
2153    for (int i = 0; i < nreturn_nodes; ++i) {
2154      opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
2155    }
2156  
2157    // TODO(skyewm): change to OutputTensor
2158    tensorflow::ImportGraphDefResults results;
2159    TF_RETURN_IF_ERROR(
2160        ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
2161  
2162    for (const auto& pair : results.return_tensors) {
2163      return_nodes->emplace_back(pair.first, pair.second);
2164    }
2165    return Status::OK();
2166  }
2167  
2168  bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
2169    if (params.cond_graph == nullptr || params.body_graph == nullptr ||
2170        params.cond_graph->parent == nullptr ||
2171        params.cond_graph->parent != params.body_graph->parent ||
2172        params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
2173        params.ninputs <= 0 || params.cond_inputs == nullptr ||
2174        params.body_inputs == nullptr || params.body_outputs == nullptr) {
2175      s->status = InvalidArgument(
2176          "TF_WhileParams must be created by successful TF_NewWhile() call");
2177      return false;
2178    }
2179    return true;
2180  }
2181  
2182  bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
2183    if (params.cond_output.oper == nullptr) {
2184      s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
2185      return false;
2186    }
2187    for (int i = 0; i < params.ninputs; ++i) {
2188      if (params.body_outputs[i].oper == nullptr) {
2189        s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
2190                                    "field isn't set");
2191        return false;
2192      }
2193    }
2194    if (params.name == nullptr) {
2195      s->status = InvalidArgument("TF_WhileParams `name` field is null");
2196      return false;
2197    }
2198    return true;
2199  }
2200  
2201  #endif  // __ANDROID__
2202  
2203  void FreeWhileResources(const TF_WhileParams* params) {
2204    TF_DeleteGraph(params->cond_graph);
2205    TF_DeleteGraph(params->body_graph);
2206    delete[] params->cond_inputs;
2207    delete[] params->body_inputs;
2208    delete[] params->body_outputs;
2209  }
2210  
2211  TF_WhileParams EmptyWhileParams() {
2212    return {0,       nullptr, nullptr, {nullptr, 0},
2213            nullptr, nullptr, nullptr, nullptr};
2214  }
2215  
2216  }  // namespace
2217  
2218  TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
2219                             TF_Status* status) {
2220  #ifdef __ANDROID__
2221    status->status = tensorflow::errors::Unimplemented(
2222        "Creating while loops is not supported in Android. File a bug at "
2223        "https://github.com/tensorflow/tensorflow/issues if this feature is "
2224        "important to you");
2225    return EmptyWhileParams();
2226  #else
2227    if (ninputs == 0) {
2228      status->status =
2229          InvalidArgument("TF_NewWhile() must be passed at least one input");
2230      return EmptyWhileParams();
2231    }
2232  
2233    TF_Graph* cond_graph = TF_NewGraph();
2234    TF_Graph* body_graph = TF_NewGraph();
2235    cond_graph->parent = g;
2236    cond_graph->parent_inputs = inputs;
2237    body_graph->parent = g;
2238    body_graph->parent_inputs = inputs;
2239  
2240    TF_Output* cond_inputs = new TF_Output[ninputs];
2241    TF_Output cond_output = {nullptr, -1};
2242    TF_Output* body_inputs = new TF_Output[ninputs];
2243    TF_Output* body_outputs = new TF_Output[ninputs];
2244    for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
2245    const char* name = nullptr;
2246  
2247    for (int i = 0; i < ninputs; ++i) {
2248      // TODO(skyewm): prefix names with underscore (requires some plumbing)
2249      if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
2250                       &cond_inputs[i], status)) {
2251        break;
2252      }
2253      if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
2254                       &body_inputs[i], status)) {
2255        break;
2256      }
2257    }
2258  
2259    TF_WhileParams params = {ninputs,    cond_graph,  cond_inputs,  cond_output,
2260                             body_graph, body_inputs, body_outputs, name};
2261  
2262    if (!status->status.ok()) {
2263      FreeWhileResources(&params);
2264      return EmptyWhileParams();
2265    }
2266    return params;
2267  #endif  // __ANDROID__
2268  }
2269  
2270  #ifndef __ANDROID__
2271  namespace {
2272  
2273  // TODO(skyewm): make nodes in while loop unfetchable like in Python version
2274  void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
2275                            TF_Output* outputs) {
2276    if (!ValidateInputWhileParams(*params, status)) return;
2277  
2278    TF_Graph* parent = params->cond_graph->parent;
2279    TF_Output* parent_inputs = params->cond_graph->parent_inputs;
2280    int num_loop_vars = params->ninputs;
2281  
2282    mutex_lock l(parent->mu);
2283  
2284    // 'cond_fn' copies the cond graph into the parent graph.
2285    tensorflow::ops::CondGraphBuilderFn cond_fn =
2286        [params, parent](const tensorflow::Scope& scope,
2287                         const std::vector<tensorflow::Output>& inputs,
2288                         tensorflow::Output* output) {
2289          DCHECK_EQ(scope.graph(), &parent->graph);
2290          std::vector<tensorflow::Output> cond_output;
2291          TF_RETURN_IF_ERROR(CopyGraph(
2292              &params->cond_graph->graph, &parent->graph, &parent->refiner,
2293              params->cond_inputs, inputs, scope.impl()->name(),
2294              scope.impl()->control_deps(), &params->cond_output,
2295              /* nreturn_nodes */ 1, &cond_output));
2296          *output = cond_output[0];
2297          return Status::OK();
2298        };
2299  
2300    // 'body_fn' copies the body graph into the parent graph.
2301    tensorflow::ops::BodyGraphBuilderFn body_fn =
2302        [params, parent, num_loop_vars](
2303            const tensorflow::Scope& scope,
2304            const std::vector<tensorflow::Output>& inputs,
2305            std::vector<tensorflow::Output>* outputs) {
2306          DCHECK_EQ(scope.graph(), &parent->graph);
2307          TF_RETURN_IF_ERROR(
2308              CopyGraph(&params->body_graph->graph, &parent->graph,
2309                        &parent->refiner, params->body_inputs, inputs,
2310                        scope.impl()->name(), scope.impl()->control_deps(),
2311                        params->body_outputs, num_loop_vars, outputs));
2312          return Status::OK();
2313        };
2314  
2315    // Create the while loop using an internal scope.
2316    tensorflow::Scope scope =
2317        NewInternalScope(&parent->graph, &status->status, &parent->refiner)
2318            .NewSubScope(params->name);
2319  
2320    const int first_new_node_id = parent->graph.num_node_ids();
2321  
2322    tensorflow::OutputList loop_outputs;
2323    status->status = tensorflow::ops::BuildWhileLoop(
2324        scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
2325        body_fn, params->name, &loop_outputs);
2326  
2327    // Update name_map with newly-created ops.
2328    // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
2329    // a bad status. Once we fix this, we may want to return early instead of
2330    // executing the following code.
2331    for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
2332      Node* new_node = parent->graph.FindNodeId(i);
2333      if (new_node == nullptr) continue;
2334      parent->name_map[new_node->name()] = new_node;
2335    }
2336  
2337    // Populate 'outputs'.
2338    DCHECK_LE(loop_outputs.size(), num_loop_vars);
2339    for (int i = 0; i < loop_outputs.size(); ++i) {
2340      outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
2341    }
2342  }
2343  
2344  }  // namespace
2345  #endif  // __ANDROID__
2346  
2347  void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
2348                      TF_Output* outputs) {
2349  #ifdef __ANDROID__
2350    status->status = tensorflow::errors::Unimplemented(
2351        "Creating while loops is not supported in Android. File a bug at "
2352        "https://github.com/tensorflow/tensorflow/issues if this feature is "
2353        "important to you");
2354  #else
2355    // If it appears the caller created or modified `params`, don't free resources
2356    if (!ValidateConstWhileParams(*params, status)) return;
2357    TF_FinishWhileHelper(params, status, outputs);
2358    FreeWhileResources(params);
2359  #endif  // __ANDROID__
2360  }
2361  
2362  void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
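// Editorial usage sketch (illustrative, not part of the original file): the
// expected client flow around TF_NewWhile()/TF_FinishWhile(). `graph`, `s`,
// and `loop_vars` (an array of `n` valid outputs of `graph`) are assumed to
// exist; building the cond and body subgraphs is elided.
//
//   TF_WhileParams params = TF_NewWhile(graph, loop_vars, n, s);
//   if (TF_GetCode(s) != TF_OK) return;
//   // ... add ops to params.cond_graph reading params.cond_inputs and set
//   //     params.cond_output; add ops to params.body_graph reading
//   //     params.body_inputs and fill params.body_outputs ...
//   params.name = "while_loop";
//   std::vector<TF_Output> outputs(n);
//   TF_FinishWhile(&params, s, outputs.data());  // frees the params' resources
//   // On an earlier construction failure, call TF_AbortWhile(&params) instead.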
2363  
2364  void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
2365                       TF_Output* dx, TF_Status* status, TF_Output* dy) {
2366  #ifdef __ANDROID__
2367    status->status = tensorflow::errors::Unimplemented(
2368        "Adding gradients is not supported in Android. File a bug at "
2369        "https://github.com/tensorflow/tensorflow/issues if this feature is "
2370        "important to you");
2371  #else
2372    std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
2373    std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
2374    std::vector<tensorflow::Output> dy_arg;
2375  
2376    {
2377      // We need to hold on to the lock while we have a scope that uses TF_Graph.
2378      mutex_lock graph_lock(g->mu);
2379  
2380      const int first_new_node_id = g->graph.num_node_ids();
2381  
2382      tensorflow::Scope scope =
2383          NewInternalScope(&g->graph, &status->status, &g->refiner)
2384              .NewSubScope("gradients");
2385  
2386      if (dx != nullptr) {
2387        std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
2388        status->status =
2389            AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
2390      } else {
2391        status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
2392      }
2393  
2394      // Update g->name_map with the name_map from the scope, which will contain
2395      // the new gradient ops.
2396      for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
2397        Node* n = g->graph.FindNodeId(i);
2398        if (n == nullptr) continue;
2399        g->name_map[n->name()] = n;
2400      }
2401    }
2402  
  // Unpack the results from dy_arg.
  TFOutputsFromOutputs(dy_arg, dy);
#endif  // __ANDROID__
}

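// Example (illustrative sketch, assuming `graph`, `y_op`, `x_op` and `status`
// were created elsewhere with the C API): computes dy/dx for one y and one x.
// When dx is nullptr, the gradients are seeded with ones, matching the
// AddSymbolicGradients overload used above.
//
//   TF_Output y[1] = {{y_op, 0}};
//   TF_Output x[1] = {{x_op, 0}};
//   TF_Output dy[1];
//   TF_AddGradients(graph, y, 1, x, 1, /*dx=*/nullptr, status, dy);
//   if (TF_GetCode(status) != TF_OK) { /* handle error */ }
//   // dy[0] now refers to an op created under the "gradients" name scope.
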
// TF_Session functions ----------------------------------------------

TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
    : session(s), graph(g), last_num_graph_nodes(0), device_mgr(nullptr) {
  if (s->LocalDeviceManager(&device_mgr).ok()) {
    devices = device_mgr->ListDevices();
  }
}

TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
                          TF_Status* status) {
  Session* session;
  status->status = NewSession(opt->options, &session);
  if (status->status.ok()) {
    TF_Session* new_session = new TF_Session(session, graph);
    if (graph != nullptr) {
      mutex_lock l(graph->mu);
      graph->sessions[new_session] = Status::OK();
    }
    return new_session;
  } else {
    DCHECK_EQ(nullptr, session);
    return nullptr;
  }
}

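// Example (illustrative sketch): creating and disposing of a session for an
// existing graph. `opts`, `graph` and `status` are placeholders assumed to
// have been created with TF_NewSessionOptions, TF_NewGraph and TF_NewStatus.
//
//   TF_Session* session = TF_NewSession(graph, opts, status);
//   if (TF_GetCode(status) == TF_OK) {
//     // ... TF_SessionRun(...) calls ...
//     TF_CloseSession(session, status);
//     TF_DeleteSession(session, status);
//   }
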
TF_Session* TF_LoadSessionFromSavedModel(
    const TF_SessionOptions* session_options, const TF_Buffer* run_options,
    const char* export_dir, const char* const* tags, int tags_len,
    TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
// TODO(ashankar): Remove the __ANDROID__ guard. This will require ensuring that
// the tensorflow/cc/saved_model:loader build target is Android friendly.
#ifdef __ANDROID__
  status->status = tensorflow::errors::Unimplemented(
      "Loading a SavedModel is not supported in Android. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
  return nullptr;
#else
  mutex_lock l(graph->mu);
  if (!graph->name_map.empty()) {
    status->status = InvalidArgument("Graph is non-empty.");
    return nullptr;
  }

  RunOptions run_options_proto;
  if (run_options != nullptr && !run_options_proto.ParseFromArray(
                                    run_options->data, run_options->length)) {
    status->status = InvalidArgument("Unparseable RunOptions proto");
    return nullptr;
  }

  std::unordered_set<string> tag_set;
  for (int i = 0; i < tags_len; i++) {
    tag_set.insert(string(tags[i]));
  }

  tensorflow::SavedModelBundle bundle;
  status->status =
      tensorflow::LoadSavedModel(session_options->options, run_options_proto,
                                 export_dir, tag_set, &bundle);
  if (!status->status.ok()) return nullptr;

  // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
  // extends using GraphDefs. The Graph instance is different, but equivalent
  // to the one used to create the session.
  //
  // TODO(jhseu): When Session is modified to take Graphs instead of
  // GraphDefs, return the Graph generated in LoadSavedModel().
  TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
  TF_ImportGraphDefResults results;
  GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
                            import_opts, &results, status);
  TF_DeleteImportGraphDefOptions(import_opts);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  if (meta_graph_def != nullptr) {
    status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
    if (!status->status.ok()) return nullptr;
  }

  TF_Session* session = new TF_Session(bundle.session.release(), graph);

  graph->sessions[session] = Status::OK();
  session->last_num_graph_nodes = graph->graph.num_node_ids();
  return session;
#endif  // __ANDROID__
}

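// Example (illustrative sketch): loading a SavedModel exported with the
// conventional "serve" tag. The export path is a placeholder; `opts`, `graph`
// and `status` are assumed to exist, and `graph` must still be empty.
//
//   const char* const tags[] = {"serve"};
//   TF_Session* session = TF_LoadSessionFromSavedModel(
//       opts, /*run_options=*/nullptr, "/path/to/export_dir", tags, 1,
//       graph, /*meta_graph_def=*/nullptr, status);
//   if (TF_GetCode(status) != TF_OK) { /* handle error */ }
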
void TF_CloseSession(TF_Session* s, TF_Status* status) {
  status->status = s->session->Close();
}

void TF_DeleteSession(TF_Session* s, TF_Status* status) {
  status->status = Status::OK();
  TF_Graph* const graph = s->graph;
  if (graph != nullptr) {
    graph->mu.lock();
    graph->sessions.erase(s);
    const bool del = graph->delete_requested && graph->sessions.empty();
    graph->mu.unlock();
    if (del) delete graph;
  }
  delete s->session;
  delete s;
}

// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
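//
// If the graph has grown since this session last saw it, serializes just the
// newly added nodes (plus the graph versions and function library) into a
// GraphDef and calls Session::Extend() with it. Returns false and leaves an
// error in `status` if the updated graph contains a cycle or Extend() fails.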
static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
  if (session->graph != nullptr) {
    mutex_lock session_lock(session->mu);
    session->graph->mu.lock();
    const Graph& graph = session->graph->graph;

    status->status = session->graph->sessions[session];
    if (!status->status.ok()) {
      session->graph->mu.unlock();
      return false;
    }

    const auto num_nodes = graph.num_node_ids();
    if (session->last_num_graph_nodes < num_nodes) {
      status->status = tensorflow::ValidateNoCycles(session->graph->graph);
      if (!status->status.ok()) {
        session->graph->mu.unlock();
        return false;
      }

      GraphDef graph_def;
      *graph_def.mutable_versions() = graph.versions();
      // Fill graph_def with nodes with ids in the range
      // [session->last_num_graph_nodes, num_nodes), that is the nodes
      // added since the last TF_SessionRun() call.
      for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
        Node* const node = graph.FindNodeId(id);
        if (node != nullptr && node->IsOp()) {
          NodeDef* const node_def = graph_def.add_node();
          *node_def = node->def();
        }
      }
      *graph_def.mutable_library() = graph.flib_def().ToProto();
      session->graph->mu.unlock();
      status->status = session->session->Extend(graph_def);
      if (!status->status.ok()) {
        // Leave last_num_graph_nodes unchanged so the new nodes are
        // included in the next Extend() attempt.
        return false;
      }
      // Note: session->session is not modified if Extend() fails, so
      // we only set last_num_graph_nodes if it succeeds.
      session->last_num_graph_nodes = num_nodes;
    } else {
      session->graph->mu.unlock();
    }
  }
  return true;
}

void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
                   const TF_Output* inputs, TF_Tensor* const* input_values,
                   int ninputs, const TF_Output* outputs,
                   TF_Tensor** output_values, int noutputs,
                   const TF_Operation* const* target_opers, int ntargets,
                   TF_Buffer* run_metadata, TF_Status* status) {
  // TODO(josh11b,mrry): Change Session to be able to use a Graph*
  // directly, instead of requiring us to serialize to a GraphDef and
  // call Session::Extend().
  if (!ExtendSessionGraphHelper(session, status)) {
    return;
  }

  TF_Run_Setup(noutputs, output_values, status);

  // Convert from TF_Output and TF_Tensor to a string and Tensor.
  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
  if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
  for (int i = 0; i < ninputs; ++i) {
    input_pairs[i].first = OutputName(inputs[i]);
  }

  // Convert from TF_Output to string names.
  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = OutputName(outputs[i]);
  }

  // Convert from TF_Operation* to string names.
  std::vector<string> target_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_names[i] = target_opers[i]->node.name();
  }

  // Actually run.
  TF_Run_Helper(session->session, nullptr, run_options, input_pairs,
                output_names, output_values, target_names, run_metadata,
                status);
}

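// Example (illustrative sketch): feeding one tensor and fetching one tensor.
// `session`, `status`, the ops `input_op`/`output_op` and the tensor
// `in_tensor` are placeholders assumed to have been created elsewhere.
//
//   TF_Output feeds[1] = {{input_op, 0}};
//   TF_Tensor* feed_values[1] = {in_tensor};
//   TF_Output fetches[1] = {{output_op, 0}};
//   TF_Tensor* fetch_values[1] = {nullptr};
//   TF_SessionRun(session, /*run_options=*/nullptr,
//                 feeds, feed_values, 1,
//                 fetches, fetch_values, 1,
//                 /*target_opers=*/nullptr, 0,
//                 /*run_metadata=*/nullptr, status);
//   if (TF_GetCode(status) == TF_OK) {
//     // ... use fetch_values[0], then TF_DeleteTensor(fetch_values[0]) ...
//   }
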
void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
                         int ninputs, const TF_Output* outputs, int noutputs,
                         const TF_Operation* const* target_opers, int ntargets,
                         const char** handle, TF_Status* status) {
  *handle = nullptr;

  if (!ExtendSessionGraphHelper(session, status)) {
    return;
  }

  std::vector<string> input_names(ninputs);
  for (int i = 0; i < ninputs; ++i) {
    input_names[i] = OutputName(inputs[i]);
  }

  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = OutputName(outputs[i]);
  }

  std::vector<string> target_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_names[i] = target_opers[i]->node.name();
  }

  string new_handle;
  status->status = session->session->PRunSetup(input_names, output_names,
                                               target_names, &new_handle);
  if (status->status.ok()) {
    char* buf = new char[new_handle.size() + 1];
    memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
    *handle = buf;
  }
}

void TF_DeletePRunHandle(const char* handle) {
  delete[] handle;
  // TODO(suharshs): Free up any resources held by the partial run state.
}

void TF_SessionPRun(TF_Session* session, const char* handle,
                    const TF_Output* inputs, TF_Tensor* const* input_values,
                    int ninputs, const TF_Output* outputs,
                    TF_Tensor** output_values, int noutputs,
                    const TF_Operation* const* target_opers, int ntargets,
                    TF_Status* status) {
  // TODO(josh11b,mrry): Change Session to be able to use a Graph*
  // directly, instead of requiring us to serialize to a GraphDef and
  // call Session::Extend().
  if (!ExtendSessionGraphHelper(session, status)) {
    return;
  }

  TF_Run_Setup(noutputs, output_values, status);

  // Convert from TF_Output and TF_Tensor to a string and Tensor.
  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
  if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
  for (int i = 0; i < ninputs; ++i) {
    input_pairs[i].first = OutputName(inputs[i]);
  }

  // Convert from TF_Output to string names.
  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = OutputName(outputs[i]);
  }

  // Convert from TF_Operation* to string names.
  std::vector<string> target_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_names[i] = target_opers[i]->node.name();
  }

  TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
                output_values, target_names, nullptr, status);
}

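// Example (illustrative sketch): a partial run that declares all feeds and
// fetches up front, then runs in two steps. The ops `a_op`, `b_op`, `m_op`
// (which depends only on `a_op`) and `out_op`, and the tensors `a_tensor` and
// `b_tensor`, are placeholders; error handling is elided.
//
//   const char* handle = nullptr;
//   TF_Output feeds[2] = {{a_op, 0}, {b_op, 0}};
//   TF_Output fetches[2] = {{m_op, 0}, {out_op, 0}};
//   TF_SessionPRunSetup(session, feeds, 2, fetches, 2,
//                       /*target_opers=*/nullptr, 0, &handle, status);
//
//   // Step 1: feed `a`, fetch `m`.
//   TF_Tensor* m_val[1] = {nullptr};
//   TF_SessionPRun(session, handle, &feeds[0], &a_tensor, 1,
//                  &fetches[0], m_val, 1, /*target_opers=*/nullptr, 0, status);
//
//   // Step 2: feed `b`, fetch `out`.
//   TF_Tensor* out_val[1] = {nullptr};
//   TF_SessionPRun(session, handle, &feeds[1], &b_tensor, 1,
//                  &fetches[1], out_val, 1, /*target_opers=*/nullptr, 0,
//                  status);
//
//   TF_DeletePRunHandle(handle);
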
TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
  tensorflow::OpList op_list;
  if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
    status->status = InvalidArgument("Unparseable OpList");
    return nullptr;
  }
  status->status = Status::OK();
  return new TF_ApiDefMap(op_list);
}

void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }

void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
                     size_t text_len, TF_Status* status) {
#ifdef __ANDROID__
  status->status = tensorflow::errors::Unimplemented(
      "ApiDefMap is not supported in Android.");
#else
  mutex_lock l(api_def_map->lock);
  if (api_def_map->update_docs_called) {
    status->status = FailedPrecondition(
        "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
        "called.");
    return;
  }
  string api_def_text(text, text_len);
  status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
#endif  // __ANDROID__
}

TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
                           size_t name_len, TF_Status* status) {
#ifdef __ANDROID__
  status->status = tensorflow::errors::Unimplemented(
      "ApiDefMap is not supported in Android.");
  return nullptr;
#else
  mutex_lock l(api_def_map->lock);
  if (!api_def_map->update_docs_called) {
    api_def_map->api_def_map.UpdateDocs();
    api_def_map->update_docs_called = true;
  }
  string name_str(name, name_len);
  const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);

  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(*api_def, ret);
  return ret;
#endif  // __ANDROID__
}
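
// Example (illustrative sketch): overlaying an ApiDef patch onto the
// registered op list and reading back the merged ApiDef for one op. The
// buffers and the "ExampleOp" name are placeholders; callers typically build
// `op_list_buf` from TF_GetAllOpList().
//
//   TF_ApiDefMap* api_map = TF_NewApiDefMap(op_list_buf, status);
//   TF_ApiDefMapPut(api_map, api_def_text, strlen(api_def_text), status);
//   TF_Buffer* api_def = TF_ApiDefMapGet(api_map, "ExampleOp", 9, status);
//   // ... parse api_def as an ApiDef proto, then TF_DeleteBuffer(api_def) ...
//   TF_DeleteApiDefMap(api_map);
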
}  // end extern "C"