/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/c_api.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <vector>

// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/platform/platform.h"  // NOLINT

#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/kernels/logging_ops.h"
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eval_const_tensor.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/version.h"

// The implementation below is at the top level instead of the
// brain namespace because we are defining 'extern "C"' functions.
using tensorflow::AllocationDescription;
using tensorflow::DataType;
using tensorflow::ExtendSessionGraphHelper;
using tensorflow::Graph;
using tensorflow::GraphDef;
using tensorflow::mutex_lock;
using tensorflow::NameRangeMap;
using tensorflow::NameRangesForNode;
using tensorflow::NewSession;
using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::NodeDef;
using tensorflow::OpDef;
using tensorflow::OpRegistry;
using tensorflow::OutputTensor;
using tensorflow::PartialTensorShape;
using tensorflow::RunMetadata;
using tensorflow::RunOptions;
using tensorflow::Session;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::Tensor;
using tensorflow::TensorBuffer;
using tensorflow::TensorId;
using tensorflow::TensorShape;
using tensorflow::TensorShapeProto;
using tensorflow::VersionDef;
using tensorflow::error::Code;
using tensorflow::errors::FailedPrecondition;
using tensorflow::errors::InvalidArgument;
using tensorflow::gtl::ArraySlice;
using tensorflow::strings::StrCat;

extern "C" {

// --------------------------------------------------------------------------
const char* TF_Version() { return TF_VERSION_STRING; }

// --------------------------------------------------------------------------
size_t TF_DataTypeSize(TF_DataType dt) {
  return static_cast<size_t>(
      tensorflow::DataTypeSize(static_cast<DataType>(dt)));
}

// --------------------------------------------------------------------------

TF_Status* TF_NewStatus() { return new TF_Status; }

void TF_DeleteStatus(TF_Status* s) { delete s; }

void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
  if (code == TF_OK) {
    s->status = Status::OK();
    return;
  }
  s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
}

TF_Code TF_GetCode(const TF_Status* s) {
  return static_cast<TF_Code>(s->status.code());
}

const char* TF_Message(const TF_Status* s) {
  return s->status.error_message().c_str();
}

// --------------------------------------------------------------------------

namespace {
class TF_ManagedBuffer : public TensorBuffer {
 public:
  TF_ManagedBuffer(void* data, size_t len,
                   void (*deallocator)(void* data, size_t len, void* arg),
                   void* deallocator_arg)
      : TensorBuffer(data),
        len_(len),
        deallocator_(deallocator),
        deallocator_arg_(deallocator_arg) {}

  const size_t len_;
  void (*const deallocator_)(void* data, size_t len, void* arg);
  void* const deallocator_arg_;

  ~TF_ManagedBuffer() override {
    (*deallocator_)(data(), len_, deallocator_arg_);
  }

  size_t size() const override { return len_; }
  TensorBuffer* root_buffer() override { return this; }
  void FillAllocationDescription(AllocationDescription* proto) const override {
    tensorflow::int64 rb = size();
    proto->set_requested_bytes(rb);
    proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
  }

  // Prevents input forwarding from mutating this buffer.
  bool OwnsMemory() const override { return false; }
};

void* allocate_tensor(const char* operation, size_t len) {
  void* data =
      tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len);
  if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
    tensorflow::LogMemory::RecordRawAllocation(
        operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID,
        len, data, tensorflow::cpu_allocator());
  }
  return data;
}

void deallocate_buffer(void* data, size_t len, void* arg) {
  if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
    tensorflow::LogMemory::RecordRawDeallocation(
        "TensorFlow C Api",
        tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data,
        tensorflow::cpu_allocator(), false);
  }
  tensorflow::cpu_allocator()->DeallocateRaw(data);
}

}  // namespace

TF_Tensor::~TF_Tensor() { buffer->Unref(); }

TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
                             int num_dims, size_t len) {
  void* data = allocate_tensor("TF_AllocateTensor", len);
  return TF_NewTensor(dtype, dims, num_dims, data, len, deallocate_buffer,
                      nullptr);
}

TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
                        void* data, size_t len,
                        void (*deallocator)(void* data, size_t len, void* arg),
                        void* deallocator_arg) {
  std::vector<tensorflow::int64> dimvec(num_dims);
  for (int i = 0; i < num_dims; ++i) {
    dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
  }

  TF_ManagedBuffer* buf = nullptr;
  if (dtype != TF_STRING && dtype != TF_RESOURCE &&
      tensorflow::DataTypeCanUseMemcpy(static_cast<DataType>(dtype)) &&
      reinterpret_cast<intptr_t>(data) % std::max(1, EIGEN_MAX_ALIGN_BYTES) !=
          0) {
    // TF_STRING and TF_RESOURCE tensors have a different representation in
    // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste
    // (any alignment requirements will be taken care of by TF_TensorToTensor
    // and TF_TensorFromTensor).
    //
    // Other types have the same representation, so copy only if it is safe to
    // do so.
    buf = new TF_ManagedBuffer(allocate_tensor("TF_NewTensor", len), len,
                               deallocate_buffer, nullptr);
    std::memcpy(buf->data(), data, len);
    // Free the original buffer.
    deallocator(data, len, deallocator_arg);
  } else {
    buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg);
  }

  TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf};
  size_t elem_size = TF_DataTypeSize(dtype);
  if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) {
    delete ret;
    return nullptr;
  }
  return ret;
}
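
// Illustrative sketch (not part of this file's API surface): a typical
// caller hands TF_NewTensor a buffer plus a deallocator that releases it.
// The names below (FreeHeapBuffer, x_tensor-style usage) are hypothetical;
// error handling is elided.
//
//   static void FreeHeapBuffer(void* data, size_t len, void* arg) {
//     ::operator delete(data);
//   }
//   ...
//   int64_t dims[2] = {2, 3};
//   size_t len = 6 * TF_DataTypeSize(TF_FLOAT);
//   void* values = ::operator new(len);
//   TF_Tensor* t =
//       TF_NewTensor(TF_FLOAT, dims, 2, values, len, FreeHeapBuffer, nullptr);
//   // ... write/read through TF_TensorData(t) ...
//   TF_DeleteTensor(t);  // runs FreeHeapBuffer, unless the unaligned-copy
//                        // path above already invoked it at creation time.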

TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
  // It is safe to move the Tensor if and only if we own the unique reference
  // to it. In that case, we might as well not delete and reallocate, but a
  // future implementation might need to do so.
  TensorBuffer* buf = tensor->buffer;
  if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
      buf->OwnsMemory()) {
    return tensor;
  }
  return nullptr;
}

void TF_DeleteTensor(TF_Tensor* t) { delete t; }

TF_DataType TF_TensorType(const TF_Tensor* t) { return t->dtype; }
int TF_NumDims(const TF_Tensor* t) { return t->shape.dims(); }
int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
  return static_cast<int64_t>(t->shape.dim_size(dim_index));
}
size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); }
void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); }

int64_t TF_TensorElementCount(const TF_Tensor* t) {
  int64_t result = 1;
  int rank = TF_NumDims(t);
  for (int dim = 0; dim < rank; ++dim) {
    result *= TF_Dim(t, dim);
  }
  return result;
}

// Returns the number of elements that would be present in a tensor with the
// given shape.
static int64_t ShapeNumElements(const int64_t* dims, int num_dims) {
  int64_t result = 1;
  for (int dim = 0; dim < num_dims; ++dim) {
    result *= dims[dim];
  }
  return result;
}

static void UnrefIfNonNull(::tensorflow::TensorBuffer* buf) {
  if (buf != nullptr) {
    buf->Unref();
  }
}

static void RefIfNonNull(::tensorflow::TensorBuffer* buf) {
  if (buf != nullptr) {
    buf->Ref();
  }
}

void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type,
                          TF_Tensor* to, const int64_t* new_dims,
                          int num_new_dims, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  size_t in_size = TF_DataTypeSize(TF_TensorType(from));
  if (in_size == 0) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "input tensor has a zero-sized data type");
    return;
  }
  size_t out_size = TF_DataTypeSize(type);
  if (out_size == 0) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "output tensor has a zero-sized data type");
    return;
  }

  if (ShapeNumElements(new_dims, num_new_dims) * out_size !=
      TF_TensorElementCount(from) * in_size) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "input tensor is not compatible with output shape");
    return;
  }

  tensorflow::TensorShapeProto p;
  for (int i = 0; i < num_new_dims; ++i) {
    p.add_dim()->set_size(new_dims[i]);
  }
  to->shape = tensorflow::TensorShape(p);
  to->dtype = type;
  if (to->buffer != from->buffer) {
    UnrefIfNonNull(to->buffer);
    to->buffer = from->buffer;
    RefIfNonNull(to->buffer);
  }
}
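
// Illustrative sketch: the bitcast succeeds only when the total byte sizes
// match. For example (hypothetical caller), a [2, 3] TF_FLOAT tensor
// (24 bytes) can be reinterpreted as a [3] TF_INT64 tensor (also 24 bytes):
//
//   int64_t new_dims[1] = {3};
//   TF_TensorBitcastFrom(src, TF_INT64, dst, new_dims, 1, status);
//   // 3 * 8 bytes == 6 * 4 bytes, so the status is TF_OK and dst now
//   // shares (and refs) src's buffer.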

// --------------------------------------------------------------------------
size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
                       size_t dst_len, TF_Status* status) {
  const size_t sz = TF_StringEncodedSize(src_len);
  if (sz < src_len) {
    status->status = InvalidArgument("src string is too large to encode");
    return 0;
  }
  if (dst_len < sz) {
    status->status =
        InvalidArgument("dst_len (", dst_len, ") too small to encode a ",
                        src_len, "-byte string");
    return 0;
  }
  dst = tensorflow::core::EncodeVarint64(dst, src_len);
  memcpy(dst, src, src_len);
  return sz;
}

static Status TF_StringDecode_Impl(const char* src, size_t src_len,
                                   const char** dst, size_t* dst_len) {
  tensorflow::uint64 len64 = 0;
  const char* p = tensorflow::core::GetVarint64Ptr(src, src + src_len, &len64);
  if (p == nullptr) {
    return InvalidArgument("invalid string encoding or truncated src buffer");
  }
  if (len64 > std::numeric_limits<size_t>::max()) {
    return InvalidArgument("encoded string is ", len64,
                           "-bytes, which is too large for this architecture");
  }
  *dst = p;
  *dst_len = static_cast<size_t>(len64);
  return Status::OK();
}

size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
                       size_t* dst_len, TF_Status* status) {
  status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len);
  if (TF_GetCode(status) != TF_OK) return 0;
  return static_cast<size_t>(*dst - src) + *dst_len;
}

size_t TF_StringEncodedSize(size_t len) {
  return static_cast<size_t>(tensorflow::core::VarintLength(len)) + len;
}
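
// Illustrative sketch: each encoded string is a varint64 length followed by
// the raw bytes, so a round trip looks like this (hypothetical caller,
// error handling elided):
//
//   const char kSrc[] = "hello";
//   char buf[16];
//   size_t n = TF_StringEncode(kSrc, 5, buf, sizeof(buf), status);
//   // n == TF_StringEncodedSize(5) == 1 (varint length byte) + 5
//   const char* decoded;
//   size_t decoded_len;
//   TF_StringDecode(buf, n, &decoded, &decoded_len, status);
//   // decoded points just past the varint prefix; decoded_len == 5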

// --------------------------------------------------------------------------
TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }

void TF_SetTarget(TF_SessionOptions* options, const char* target) {
  options->options.target = target;
}

void TF_SetConfig(TF_SessionOptions* options, const void* proto,
                  size_t proto_len, TF_Status* status) {
  if (!options->options.config.ParseFromArray(proto, proto_len)) {
    status->status = InvalidArgument("Unparseable ConfigProto");
  }
}
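
// Illustrative sketch: callers pass a serialized tensorflow::ConfigProto.
// From C++, this might look like (hypothetical caller):
//
//   tensorflow::ConfigProto config;
//   config.set_intra_op_parallelism_threads(2);
//   std::string serialized = config.SerializeAsString();
//   TF_SetConfig(opts, serialized.data(), serialized.size(), status);
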
// --------------------------------------------------------------------------
TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; }

TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
  void* copy = tensorflow::port::Malloc(proto_len);
  memcpy(copy, proto, proto_len);

  TF_Buffer* buf = new TF_Buffer;
  buf->data = copy;
  buf->length = proto_len;
  buf->data_deallocator = [](void* data, size_t length) {
    tensorflow::port::Free(data);
  };
  return buf;
}

void TF_DeleteBuffer(TF_Buffer* buffer) {
  if (buffer == nullptr) return;
  if (buffer->data_deallocator != nullptr) {
    (*buffer->data_deallocator)(const_cast<void*>(buffer->data),
                                buffer->length);
  }
  delete buffer;
}

TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }
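
// Illustrative sketch: TF_NewBufferFromString copies its input, so the caller
// keeps ownership of the original bytes and later frees the buffer as a whole
// (hypothetical caller):
//
//   TF_Buffer* b =
//       TF_NewBufferFromString(serialized.data(), serialized.size());
//   // ... hand b to an API that reads it ...
//   TF_DeleteBuffer(b);  // runs data_deallocator, then frees the struct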

// --------------------------------------------------------------------------

TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
                                              TF_Status* status) {
  Session* session;
  status->status = NewSession(opt->options, &session);
  if (TF_GetCode(status) == TF_OK) {
    return new TF_DeprecatedSession({session});
  } else {
    DCHECK_EQ(nullptr, session);
    return nullptr;
  }
}

void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
  status->status = s->session->Close();
}

void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
  status->status = Status::OK();
  if (s == nullptr) return;
  delete s->session;
  delete s;
}

void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
                    size_t proto_len, TF_Status* status) {
  GraphDef g;
  if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
    status->status = InvalidArgument("Invalid GraphDef");
    return;
  }
  status->status = s->session->Extend(g);
}

static void DeleteArray(void* data, size_t size, void* arg) {
  DCHECK_EQ(data, arg);
  delete[] reinterpret_cast<char*>(arg);
}

}  // end extern "C"

namespace tensorflow {
namespace {

// Reset helper for converting character arrays to string vectors.
void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers,
                     int ncontainers, TF_Status* status) {
  std::vector<string> container_names(ncontainers);
  for (int i = 0; i < ncontainers; ++i) {
    container_names[i] = containers[i];
  }

  status->status = Reset(opt->options, container_names);
}

}  // namespace
}  // namespace tensorflow

extern "C" {

void TF_Reset(const TF_SessionOptions* opt, const char** containers,
              int ncontainers, TF_Status* status) {
  tensorflow::TF_Reset_Helper(opt, containers, ncontainers, status);
}

}  // end extern "C"

namespace tensorflow {

Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
  if (src->dtype == TF_RESOURCE) {
    if (src->shape.dims() != 0) {
      return InvalidArgument(
          "Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
          "shape ",
          src->shape.DebugString());
    }
    *dst = Tensor(DT_RESOURCE, src->shape);
    if (!dst->scalar<ResourceHandle>()().ParseFromString(
            string(static_cast<const char*>(TF_TensorData(src)),
                   TF_TensorByteSize(src)))) {
      return InvalidArgument(
          "Malformed TF_RESOURCE tensor: unable to parse resource handle");
    }
    return Status::OK();
  }
  if (src->dtype != TF_STRING) {
    *dst = TensorCApi::MakeTensor(src->dtype, src->shape, src->buffer);
    return Status::OK();
  }
  // TF_STRING tensors require copying since the Tensor class expects a
  // sequence of string objects.
  const tensorflow::int64 num_elements = src->shape.num_elements();
  const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
  const size_t src_size = TF_TensorByteSize(src);
  if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
      num_elements) {
    return InvalidArgument(
        "Malformed TF_STRING tensor; too short to hold number of elements");
  }
  const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
  const char* limit = input + src_size;

  *dst = Tensor(static_cast<DataType>(src->dtype), src->shape);
  auto dstarray = dst->flat<string>();
  for (tensorflow::int64 i = 0; i < num_elements; ++i) {
    tensorflow::uint64 offset =
        reinterpret_cast<const tensorflow::uint64*>(input)[i];
    if (static_cast<ptrdiff_t>(offset) >= (limit - data_start)) {
      return InvalidArgument("Malformed TF_STRING tensor; element ", i,
                             " out of range");
    }
    size_t len;
    const char* p;
    const char* srcp = data_start + offset;
    Status status = TF_StringDecode_Impl(srcp, limit - srcp, &p, &len);
    if (!status.ok()) return status;
    dstarray(i).assign(p, len);
  }
  return Status::OK();
}

// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
// result in a zero-sized tensor.
static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) {
  static char empty;
  tensorflow::int64 nelems = 1;
  std::vector<tensorflow::int64> dims;
  for (int i = 0; i < shape.dims(); ++i) {
    dims.push_back(shape.dim_size(i));
    nelems *= shape.dim_size(i);
  }
  CHECK_EQ(nelems, 0);
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  return TF_NewTensor(
      dtype, reinterpret_cast<const int64_t*>(dims.data()), shape.dims(),
      reinterpret_cast<void*>(&empty), 0, [](void*, size_t, void*) {}, nullptr);
}

// Non-static for testing.
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
                               TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  if (!src.IsInitialized()) {
    status->status = FailedPrecondition(
        "attempt to use a tensor with an uninitialized value");
    return nullptr;
  }
  if (src.NumElements() == 0) {
    return EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
  }
  if (src.dtype() == DT_RESOURCE) {
    if (src.shape().dims() != 0) {
      status->status = InvalidArgument(
          "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
          src.shape().DebugString(),
          "). Please file a bug at "
          "https://github.com/tensorflow/tensorflow/issues/new, "
          "ideally with a short code snippet that reproduces this error.");
      return nullptr;
    }
    const string str = src.scalar<ResourceHandle>()().SerializeAsString();
    TF_Tensor* t = TF_AllocateTensor(TF_RESOURCE, {}, 0, str.size());
    std::memcpy(TF_TensorData(t), str.c_str(), str.size());
    return t;
  }
  if (src.dtype() != DT_STRING) {
    TensorBuffer* buf = TensorCApi::Buffer(src);
    buf->Ref();
    return new TF_Tensor{static_cast<TF_DataType>(src.dtype()), src.shape(),
                         buf};
  }
  // DT_STRING tensors require copying since TF_Tensor.buffer expects a flatly
  // encoded sequence of strings.

  // Compute bytes needed for encoding.
  size_t size = 0;
  const auto& srcarray = src.flat<string>();
  for (int i = 0; i < srcarray.size(); ++i) {
    const string& s = srcarray(i);
    // uint64 starting_offset, TF_StringEncode-d string.
    size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size());
  }

  // Encode all strings.
  char* base = new char[size];
  char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size();
  char* dst = data_start;  // Where the next string is encoded.
  size_t dst_len = size - static_cast<size_t>(data_start - base);
  tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
  for (int i = 0; i < srcarray.size(); ++i) {
    *offsets = (dst - data_start);
    offsets++;
    const string& s = srcarray(i);
    size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
    if (TF_GetCode(status) != TF_OK) {
      status->status = InvalidArgument(
          "invalid string tensor encoding (string #", i, " of ",
          srcarray.size(), "): ", status->status.error_message());
      delete[] base;
      return nullptr;
    }
    dst += consumed;
    dst_len -= consumed;
  }
  if (dst != base + size) {
    status->status = InvalidArgument(
        "invalid string tensor encoding (decoded ", (dst - base),
        " bytes, but the tensor is encoded in ", size, " bytes)");
    delete[] base;
    return nullptr;
  }

  auto dims = src.shape().dim_sizes();
  std::vector<tensorflow::int64> dimvec(dims.size());
  for (size_t i = 0; i < dims.size(); ++i) {
    dimvec[i] = dims[i];
  }
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  return TF_NewTensor(TF_STRING,
                      reinterpret_cast<const int64_t*>(dimvec.data()),
                      dimvec.size(), base, size, DeleteArray, base);
}

Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
                       TF_Buffer* out) {
  if (out->data != nullptr) {
    return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
  }
  const size_t proto_size = in.ByteSizeLong();
  void* buf = tensorflow::port::Malloc(proto_size);
  if (buf == nullptr) {
    return tensorflow::errors::ResourceExhausted(
        "Failed to allocate memory to serialize message of type '",
        in.GetTypeName(), "' and size ", proto_size);
  }
  // SerializeToArray takes size as an int.
  // This next 'if' is a workaround until we update to depend on a version
  // of protocol buffers that includes
  // https://github.com/google/protobuf/pull/4739
  if (proto_size > std::numeric_limits<int>::max()) {
    tensorflow::port::Free(buf);  // Avoid leaking the allocation on error.
    return InvalidArgument("Cannot serialize protocol buffer of type ",
                           in.GetTypeName(), " as the serialized size (",
                           proto_size,
                           " bytes) would be larger than the limit (",
                           std::numeric_limits<int>::max(), " bytes)");
  }
  if (!in.SerializeToArray(buf, proto_size)) {
    return InvalidArgument("Unable to serialize ", in.GetTypeName(),
                           " protocol buffer, perhaps the serialized size (",
                           proto_size, " bytes) is too large?");
  }
  out->data = buf;
  out->length = proto_size;
  out->data_deallocator = [](void* data, size_t length) {
    tensorflow::port::Free(data);
  };
  return Status::OK();
}

void RecordMutation(TF_Graph* graph, const TF_Operation& op,
                    const char* mutation_type) {
  // If any session has already run this node_id, mark this session as
  // unrunnable.
  for (auto it : graph->sessions) {
    mutex_lock session_lock(it.first->mu);
    if (it.first->last_num_graph_nodes > op.node.id()) {
      it.second = strings::StrCat(
          "Operation '", op.node.DebugString(), "' was changed by ",
          mutation_type,
          " after it was run by a session. This mutation will have no effect, "
          "and will trigger an error in the future. Either don't modify "
          "nodes after running them or create a new session.");
    }
  }
}

namespace {

// Helper method that creates a shape handle for a shape described by dims.
tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
    tensorflow::shape_inference::InferenceContext* ic, int num_dims,
    const int64_t* dims) {
  if (num_dims != -1) {
    std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
    dim_vec.reserve(num_dims);
    for (int i = 0; i < num_dims; ++i) {
      dim_vec.push_back(ic->MakeDim(dims[i]));
    }
    return ic->MakeShape(dim_vec);
  } else {
    return ic->UnknownShape();
  }
}

}  // namespace

void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
                                           int num_shapes_and_types,
                                           const int64_t** shapes,
                                           const int* ranks,
                                           const TF_DataType* types,
                                           TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return;
  }

  auto shape_and_type_vec =
      std::vector<tensorflow::shape_inference::ShapeAndType>(
          num_shapes_and_types);
  for (int i = 0; i < num_shapes_and_types; ++i) {
    tensorflow::shape_inference::ShapeHandle shape_handle =
        ShapeHandleFromDims(ic, ranks[i], shapes[i]);
    shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
        shape_handle, static_cast<DataType>(types[i]));
  }

  ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
}

// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result,
                   const void** buf, size_t* len);

// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
  if (session->graph != nullptr) {
    // Take the graph lock before the session lock to avoid deadlock. This is
    // safe since session->graph does not change.
    session->graph->mu.lock();
    mutex_lock session_lock(session->mu);
    const Graph& graph = session->graph->graph;

    const string& mutation_warning = session->graph->sessions[session];
    if (!mutation_warning.empty()) {
      // TODO(b/74949947): turn this back into an error status
      LOG(WARNING) << mutation_warning;
      session->graph->sessions[session].clear();
    }

    const auto num_nodes = graph.num_node_ids();
    if (session->last_num_graph_nodes < num_nodes) {
      // TODO(nolivia): check this on a subset of the graph instead of all of
      // it.
      status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
      if (TF_GetCode(status) != TF_OK) {
        session->graph->mu.unlock();
        return false;
      }

      GraphDef graph_def;
      *graph_def.mutable_versions() = graph.versions();
      // Fill graph_def with nodes with ids in the range
      // [session->last_num_graph_nodes, num_nodes), that is, the nodes
      // added since the last TF_SessionRun() call.
      for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
        Node* const node = graph.FindNodeId(id);
        if (node != nullptr && node->IsOp()) {
          NodeDef* const node_def = graph_def.add_node();
          *node_def = node->def();
        }
      }
      *graph_def.mutable_library() = graph.flib_def().ToProto();
      session->graph->mu.unlock();
      status->status = session->session->Extend(graph_def);
      if (TF_GetCode(status) != TF_OK) {
        // Contract is we always delete input_values[i].
        return false;
      }
      // Note: session->session is not modified if Extend() fails, so
      // we only set last_num_graph_nodes if it succeeds.
      session->last_num_graph_nodes = num_nodes;
    } else {
      session->graph->mu.unlock();
    }
  }
  return true;
}

}  // namespace tensorflow

static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
                         TF_Status* status) {
  status->status = Status::OK();
  for (int i = 0; i < noutputs; ++i) {
    c_outputs[i] = nullptr;
  }
}

static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
                          std::vector<std::pair<string, Tensor>>* input_pairs,
                          TF_Status* status) {
  const int ninputs = input_pairs->size();
  for (int i = 0; i < ninputs; ++i) {
    status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
    if (TF_GetCode(status) != TF_OK) return false;
  }
  return true;
}

static void TF_Run_Helper(
    Session* session, const char* handle, const TF_Buffer* run_options,
    // Input tensors
    const std::vector<std::pair<string, Tensor>>& input_pairs,
    // Output tensors
    const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
    // Target nodes
    const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
    TF_Status* status) {
  const int noutputs = output_tensor_names.size();
  std::vector<Tensor> outputs(noutputs);
  Status result;

  if (handle == nullptr) {
    RunOptions run_options_proto;
    if (run_options != nullptr &&
        !run_options_proto.ParseFromArray(run_options->data,
                                          run_options->length)) {
      status->status = InvalidArgument("Unparseable RunOptions proto");
      return;
    }
    if (run_metadata != nullptr && run_metadata->data != nullptr) {
      status->status =
          InvalidArgument("Passing non-empty run_metadata is invalid.");
      return;
    }

    RunMetadata run_metadata_proto;
    result = session->Run(run_options_proto, input_pairs, output_tensor_names,
                          target_oper_names, &outputs, &run_metadata_proto);

    // Serialize back to the upstream client, which now owns the new buffer.
    if (run_metadata != nullptr) {
      status->status = MessageToBuffer(run_metadata_proto, run_metadata);
      if (TF_GetCode(status) != TF_OK) return;
    }
  } else {
    // NOTE(zongheng): PRun does not support RunOptions yet.
    result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
  }
  if (!result.ok()) {
    status->status = result;
    return;
  }

  // Store results in c_outputs[]
  for (int i = 0; i < noutputs; ++i) {
    const Tensor& src = outputs[i];
    if (!src.IsInitialized() || src.NumElements() == 0) {
      c_outputs[i] =
          EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
      continue;
    }
    c_outputs[i] = TF_TensorFromTensor(src, status);
    if (TF_GetCode(status) != TF_OK) return;
  }
}

extern "C" {

void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
            // Input tensors
            const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
            // Output tensors
            const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
            // Target nodes
            const char** c_target_oper_names, int ntargets,
            TF_Buffer* run_metadata, TF_Status* status) {
  TF_Run_Setup(noutputs, c_outputs, status);
  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
  if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
  for (int i = 0; i < ninputs; ++i) {
    input_pairs[i].first = c_input_names[i];
  }
  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = c_output_names[i];
  }
  std::vector<string> target_oper_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_oper_names[i] = c_target_oper_names[i];
  }
  TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
                c_outputs, target_oper_names, run_metadata, status);
}
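
// Illustrative sketch (hypothetical caller): feeding one input by name and
// fetching one output through the deprecated-session API; error handling
// elided, and x_tensor is assumed to exist.
//
//   const char* input_names[] = {"x:0"};
//   TF_Tensor* inputs[] = {x_tensor};
//   const char* output_names[] = {"y:0"};
//   TF_Tensor* outputs[1] = {nullptr};
//   TF_Run(session, nullptr, input_names, inputs, 1, output_names, outputs, 1,
//          nullptr, 0, nullptr, status);
//   // On TF_OK, outputs[0] is populated; the caller releases it with
//   // TF_DeleteTensor when done.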

void TF_PRunSetup(TF_DeprecatedSession* s,
                  // Input names
                  const char** c_input_names, int ninputs,
                  // Output names
                  const char** c_output_names, int noutputs,
                  // Target nodes
                  const char** c_target_oper_names, int ntargets,
                  const char** handle, TF_Status* status) {
  *handle = nullptr;

  std::vector<string> input_names(ninputs);
  std::vector<string> output_names(noutputs);
  std::vector<string> target_oper_names(ntargets);
  for (int i = 0; i < ninputs; ++i) {
    input_names[i] = c_input_names[i];
  }
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = c_output_names[i];
  }
  for (int i = 0; i < ntargets; ++i) {
    target_oper_names[i] = c_target_oper_names[i];
  }
  string new_handle;
  status->status = s->session->PRunSetup(input_names, output_names,
                                         target_oper_names, &new_handle);
  if (TF_GetCode(status) == TF_OK) {
    char* buf = new char[new_handle.size() + 1];
    memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
    *handle = buf;
  }
}

void TF_PRun(TF_DeprecatedSession* s, const char* handle,
             // Input tensors
             const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
             // Output tensors
             const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
             // Target nodes
             const char** c_target_oper_names, int ntargets,
             TF_Status* status) {
  TF_Run_Setup(noutputs, c_outputs, status);
  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
  if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
  for (int i = 0; i < ninputs; ++i) {
    input_pairs[i].first = c_input_names[i];
  }

  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = c_output_names[i];
  }
  std::vector<string> target_oper_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_oper_names[i] = c_target_oper_names[i];
  }
  TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
                c_outputs, target_oper_names, nullptr, status);
}

TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
  TF_Library* lib_handle = new TF_Library;
  status->status = tensorflow::LoadLibrary(
      library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
      &lib_handle->op_list.length);
  if (TF_GetCode(status) != TF_OK) {
    delete lib_handle;
    return nullptr;
  }
  return lib_handle;
}

TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }

void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
  if (lib_handle == nullptr) return;
  tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
  delete lib_handle;
}

TF_Buffer* TF_GetAllOpList() {
  std::vector<tensorflow::OpDef> op_defs;
  tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
  tensorflow::OpList op_list;
  for (const auto& op : op_defs) {
    *(op_list.add_op()) = op;
  }
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(op_list, ret));
  return ret;
}

// --------------------------------------------------------------------------
// ListDevices & SessionListDevices API

void TF_DeleteDeviceList(TF_DeviceList* s) { delete s; }

TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
  TF_DeviceList* response = new TF_DeviceList;
  status->status = session->session->ListDevices(&response->response);
  return response;
}

TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
                                               TF_Status* status) {
  TF_DeviceList* response = new TF_DeviceList;
  status->status = session->session->ListDevices(&response->response);
  return response;
}

int TF_DeviceListCount(const TF_DeviceList* list) {
  return list->response.size();
}

#define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
  return_type method_name(const TF_DeviceList* list, const int index,     \
                          TF_Status* status) {                            \
    if (list == nullptr) {                                                \
      status->status = InvalidArgument("list is null!");                  \
      return err_val;                                                     \
    }                                                                     \
    if (index < 0 || index >= list->response.size()) {                    \
      status->status = InvalidArgument("index out of bounds");            \
      return err_val;                                                     \
    }                                                                     \
    status->status = Status::OK();                                        \
    return list->response[index].accessor;                                \
  }

TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
                     nullptr);
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);

#undef TF_DEVICELIST_METHOD
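
// Illustrative sketch: each TF_DEVICELIST_METHOD invocation above expands to
// a bounds-checked accessor. For instance, TF_DeviceListName expands to
// roughly:
//
//   const char* TF_DeviceListName(const TF_DeviceList* list, const int index,
//                                 TF_Status* status) {
//     if (list == nullptr) { /* set InvalidArgument */ return nullptr; }
//     if (index < 0 || index >= list->response.size()) {
//       /* set InvalidArgument */ return nullptr;
//     }
//     status->status = Status::OK();
//     return list->response[index].name().c_str();
//   }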

}  // end extern "C"

// --------------------------------------------------------------------------
// New Graph and Session API

// Helper functions -----------------------------------------------------------

namespace {

TF_Operation* ToOperation(Node* node) {
  return static_cast<TF_Operation*>(static_cast<void*>(node));
}

string OutputName(const TF_Output& output) {
  return StrCat(output.oper->node.name(), ":", output.index);
}

const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
                                          const char* attr_name,
                                          TF_Status* status) {
  const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
  if (attr == nullptr) {
    status->status = InvalidArgument("Operation '", oper->node.name(),
                                     "' has no attr named '", attr_name, "'.");
  }
  return attr;
}

TensorId ToTensorId(const TF_Output& output) {
  return TensorId(output.oper->node.name(), output.index);
}

#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
                                                     int n) {
  std::vector<tensorflow::Output> outputs(n);
  for (int i = 0; i < n; ++i) {
    outputs[i] =
        tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
  }
  return outputs;
}

void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
                          TF_Output* tf_outputs) {
  for (int i = 0; i < outputs.size(); i++) {
    tf_outputs[i].oper = ToOperation(outputs[i].node());
    tf_outputs[i].index = outputs[i].index();
  }
}
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

}  // namespace

// Shape functions -----------------------------------------------------------

void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
                            const int64_t* dims, const int num_dims,
                            TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return;
  }
  tensorflow::shape_inference::ShapeHandle new_shape =
      tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
  status->status = graph->refiner.SetShape(node, output.index, new_shape);
}

int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
                             TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return -1;
  }

  tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);

  // Unknown rank means the number of dimensions is -1.
  if (!ic->RankKnown(shape)) {
    return -1;
  }

  return ic->Rank(shape);
}

void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
                            int num_dims, TF_Status* status) {
  Node* node = &output.oper->node;

  mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(node);
  if (ic == nullptr) {
    status->status =
        InvalidArgument("Node ", node->name(), " was not found in the graph");
    return;
  }

  tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);

  int rank = -1;
  if (ic->RankKnown(shape)) {
    rank = ic->Rank(shape);
  }

  if (num_dims != rank) {
    status->status = InvalidArgument("Expected rank is ", num_dims,
                                     " but actual rank is ", rank);
    return;
  }

  if (num_dims == 0) {
    // Output shape is a scalar.
    return;
  }

  // Rank is greater than 0, so fill in the values, if known, and
  // -1 for unknown values.
  for (int i = 0; i < num_dims; ++i) {
    auto dim = ic->Dim(shape, i);
    tensorflow::int64 value = -1;
    if (ic->ValueKnown(dim)) {
      value = ic->Value(dim);
    }
    dims[i] = value;
  }
}
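
// Illustrative sketch (hypothetical caller): query the rank first, then the
// dimensions, passing storage of exactly that rank.
//
//   int rank = TF_GraphGetTensorNumDims(graph, out, status);
//   if (rank > 0) {
//     std::vector<int64_t> dims(rank);
//     TF_GraphGetTensorShape(graph, out, dims.data(), rank, status);
//     // dims[i] == -1 wherever the shape refiner does not know the value.
//   }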

// TF_OperationDescription functions ------------------------------------------

extern "C" {

static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
                                                      const char* op_type,
                                                      const char* oper_name)
    EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
  return new TF_OperationDescription(graph, op_type, oper_name);
}

TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
                                         const char* oper_name) {
  mutex_lock l(graph->mu);
  return TF_NewOperationLocked(graph, op_type, oper_name);
}

void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
  desc->node_builder.Device(device);
}

void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
  desc->node_builder.Input(&input.oper->node, input.index);
}

void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
                     int num_inputs) {
  std::vector<NodeBuilder::NodeOut> input_list;
  input_list.reserve(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
  }
  desc->node_builder.Input(input_list);
}

void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
  desc->node_builder.ControlInput(&input->node);
}

void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
  desc->colocation_constraints.emplace(
      StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
}

void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
                      const void* value, size_t length) {
  tensorflow::StringPiece s(static_cast<const char*>(value), length);
  desc->node_builder.Attr(attr_name, s);
}

void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
                          const void* const* values, const size_t* lengths,
                          int num_values) {
  if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
    desc->colocation_constraints.clear();
    for (int i = 0; i < num_values; ++i) {
      desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
                                           lengths[i]);
    }
  } else {
    std::vector<tensorflow::StringPiece> v;
    v.reserve(num_values);
    for (int i = 0; i < num_values; ++i) {
      v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
    }
    desc->node_builder.Attr(attr_name, v);
  }
}

void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
                   int64_t value) {
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
}

void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
                       const int64_t* values, int num_values) {
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  desc->node_builder.Attr(
      attr_name,
      ArraySlice<const tensorflow::int64>(
          reinterpret_cast<const tensorflow::int64*>(values), num_values));
}

void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
                     float value) {
  desc->node_builder.Attr(attr_name, value);
}

void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
                         const float* values, int num_values) {
  desc->node_builder.Attr(attr_name,
                          ArraySlice<const float>(values, num_values));
}

void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
                    unsigned char value) {
  desc->node_builder.Attr(attr_name, static_cast<bool>(value));
}

void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
                        const unsigned char* values, int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  desc->node_builder.Attr(attr_name,
                          ArraySlice<const bool>(b.get(), num_values));
}

void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
                    TF_DataType value) {
  desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
}

void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
                        const TF_DataType* values, int num_values) {
  desc->node_builder.Attr(
      attr_name, ArraySlice<const DataType>(
                     reinterpret_cast<const DataType*>(values), num_values));
}

void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name,
                           const char* placeholder) {
  tensorflow::AttrValue attr_value;
  attr_value.set_placeholder(placeholder);
  desc->node_builder.Attr(attr_name, attr_value);
}

void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
                        const char* value, size_t length) {
  tensorflow::NameAttrList func_name;
  func_name.set_name(string(value, value + length));
  desc->node_builder.Attr(attr_name, func_name);
}

void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
                     const int64_t* dims, int num_dims) {
  PartialTensorShape shape;
  if (num_dims >= 0) {
    static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                  "64-bit int types should match in size");
    shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
        reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
  }
  desc->node_builder.Attr(attr_name, shape);
}

void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
                         const int64_t* const* dims, const int* num_dims,
                         int num_shapes) {
  std::vector<PartialTensorShape> shapes;
  shapes.reserve(num_shapes);
  for (int i = 0; i < num_shapes; ++i) {
    if (num_dims[i] < 0) {
      shapes.emplace_back();
    } else {
      static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                    "64-bit int types should match in size");
      shapes.emplace_back(ArraySlice<tensorflow::int64>(
          reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
    }
  }
  desc->node_builder.Attr(attr_name, shapes);
}
1361
TF_SetAttrTensorShapeProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)1362 void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
1363 const char* attr_name, const void* proto,
1364 size_t proto_len, TF_Status* status) {
1365 // shape.ParseFromArray takes an int as length, this function takes size_t,
1366 // make sure there is no information loss.
1367 if (proto_len > std::numeric_limits<int>::max()) {
1368 status->status = InvalidArgument(
1369 "proto_len (", proto_len,
1370 " bytes) is too large to be parsed by the protocol buffer library");
1371 return;
1372 }
1373 TensorShapeProto shape;
1374 if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
1375 desc->node_builder.Attr(attr_name, shape);
1376 status->status = Status::OK();
1377 } else {
1378 status->status = InvalidArgument("Unparseable TensorShapeProto");
1379 }
1380 }
1381
1382 void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
1383 const char* attr_name,
1384 const void* const* protos,
1385 const size_t* proto_lens, int num_shapes,
1386 TF_Status* status) {
1387 std::vector<TensorShapeProto> shapes;
1388 shapes.resize(num_shapes);
1389 for (int i = 0; i < num_shapes; ++i) {
1390 if (proto_lens[i] > std::numeric_limits<int>::max()) {
1391 status->status = InvalidArgument(
1392 "length of element ", i, " in the list (", proto_lens[i],
1393 " bytes) is too large to be parsed by the protocol buffer library");
1394 return;
1395 }
1396 if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) {
1397 status->status =
1398 InvalidArgument("Unparseable TensorShapeProto at index ", i);
1399 return;
1400 }
1401 }
1402 desc->node_builder.Attr(attr_name, shapes);
1403 status->status = Status::OK();
1404 }
1405
1406 void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
1407 TF_Tensor* value, TF_Status* status) {
1408 Tensor t;
1409 status->status = TF_TensorToTensor(value, &t);
1410 if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t);
1411 }
1412
1413 void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
1414 TF_Tensor* const* values, int num_values,
1415 TF_Status* status) {
1416 status->status = Status::OK();
1417 std::vector<Tensor> t;
1418 t.reserve(num_values);
1419
1420 for (int i = 0; i < num_values && TF_GetCode(status) == TF_OK; ++i) {
1421 Tensor v;
1422 status->status = TF_TensorToTensor(values[i], &v);
1423 t.emplace_back(v);
1424 }
1425
1426 if (TF_GetCode(status) == TF_OK) desc->node_builder.Attr(attr_name, t);
1427 }
1428
1429 void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
1430 const void* proto, size_t proto_len,
1431 TF_Status* status) {
1432 tensorflow::AttrValue attr_value;
1433 if (!attr_value.ParseFromArray(proto, proto_len)) {
1434 status->status = InvalidArgument("Unparseable AttrValue proto");
1435 return;
1436 }
1437
1438 if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1439 if (attr_value.value_case() != tensorflow::AttrValue::kList &&
1440 attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
1441 status->status =
1442 InvalidArgument("Expected \"list\" field for \"",
1443 tensorflow::kColocationAttrName, "\" attribute");
1444 return;
1445 }
1446 desc->colocation_constraints.clear();
1447 for (const string& location : attr_value.list().s()) {
1448 desc->colocation_constraints.insert(location);
1449 }
1450 } else {
1451 desc->node_builder.Attr(attr_name, attr_value);
1452 }
1453
1454 status->status = Status::OK();
1455 }
1456
1457 static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
1458 TF_Status* status)
1459 EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
1460 Node* ret = nullptr;
1461
1462 if (desc->graph->name_map.count(desc->node_builder.node_name())) {
1463 status->status = InvalidArgument("Duplicate node name in graph: '",
1464 desc->node_builder.node_name(), "'");
1465 } else {
1466 if (!desc->colocation_constraints.empty()) {
1467 desc->node_builder.Attr(
1468 tensorflow::kColocationAttrName,
1469 std::vector<string>(desc->colocation_constraints.begin(),
1470 desc->colocation_constraints.end()));
1471 }
1472 status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);
1473
1474 if (TF_GetCode(status) == TF_OK) {
1475 // Run shape inference function for newly added node.
1476 status->status = desc->graph->refiner.AddNode(ret);
1477 }
1478 if (TF_GetCode(status) == TF_OK) {
1479 // Add the node to the name-to-node mapping.
1480 desc->graph->name_map[ret->name()] = ret;
1481 } else if (ret != nullptr) {
1482 desc->graph->graph.RemoveNode(ret);
1483 ret = nullptr;
1484 }
1485 }
1486
1487 delete desc;
1488
1489 return ToOperation(ret);
1490 }
1491
1492 TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
1493 TF_Status* status) {
1494 mutex_lock l(desc->graph->mu);
1495 return TF_FinishOperationLocked(desc, status);
1496 }
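
// A minimal construction sketch (hypothetical caller code; `graph`, `tensor`,
// and `status` are assumed to exist). Attributes are only staged on the
// description; validation and shape inference happen in TF_FinishOperation,
// which frees `desc` whether or not it succeeds.
//
//   TF_OperationDescription* desc = TF_NewOperation(graph, "Const", "my_const");
//   TF_SetAttrTensor(desc, "value", tensor, status);
//   TF_SetAttrType(desc, "dtype", TF_TensorType(tensor));
//   TF_Operation* op = TF_FinishOperation(desc, status);  // null on failure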
1497
1498 // TF_Operation functions
1499 // ----------------------------------------------------------
1500
1501 const char* TF_OperationName(TF_Operation* oper) {
1502 return oper->node.name().c_str();
1503 }
1504
1505 const char* TF_OperationOpType(TF_Operation* oper) {
1506 return oper->node.type_string().c_str();
1507 }
1508
1509 const char* TF_OperationDevice(TF_Operation* oper) {
1510 return oper->node.requested_device().c_str();
1511 }
1512
1513 int TF_OperationNumOutputs(TF_Operation* oper) {
1514 return oper->node.num_outputs();
1515 }
1516
1517 TF_DataType TF_OperationOutputType(TF_Output oper_out) {
1518 return static_cast<TF_DataType>(
1519 oper_out.oper->node.output_type(oper_out.index));
1520 }
1521
1522 int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
1523 TF_Status* status) {
1524 NameRangeMap name_ranges;
1525 status->status =
1526 NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
1527 if (TF_GetCode(status) != TF_OK) return -1;
1528 auto iter = name_ranges.find(arg_name);
1529 if (iter == name_ranges.end()) {
1530 status->status = InvalidArgument("Output arg '", arg_name, "' not found");
1531 return -1;
1532 }
1533 return iter->second.second - iter->second.first;
1534 }
1535
1536 int TF_OperationNumInputs(TF_Operation* oper) {
1537 return oper->node.num_inputs();
1538 }
1539
1540 TF_DataType TF_OperationInputType(TF_Input oper_in) {
1541 return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
1542 }
1543
1544 int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
1545 TF_Status* status) {
1546 NameRangeMap name_ranges;
1547 status->status =
1548 NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
1549 if (TF_GetCode(status) != TF_OK) return -1;
1550 auto iter = name_ranges.find(arg_name);
1551 if (iter == name_ranges.end()) {
1552 status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1553 return -1;
1554 }
1555 return iter->second.second - iter->second.first;
1556 }
1557
1558 TF_Output TF_OperationInput(TF_Input oper_in) {
1559 const tensorflow::Edge* edge;
1560 Status s = oper_in.oper->node.input_edge(oper_in.index, &edge);
1561 if (!s.ok()) {
1562 return {nullptr, -1};
1563 }
1564
1565 return {ToOperation(edge->src()), edge->src_output()};
1566 }
1567
1568 int TF_OperationOutputNumConsumers(TF_Output oper_out) {
1569 int count = 0;
1570 for (const auto* edge : oper_out.oper->node.out_edges()) {
1571 if (edge->src_output() == oper_out.index) {
1572 ++count;
1573 }
1574 }
1575 return count;
1576 }
1577
1578 int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
1579 int max_consumers) {
1580 int count = 0;
1581 for (const auto* edge : oper_out.oper->node.out_edges()) {
1582 if (edge->src_output() == oper_out.index) {
1583 if (count < max_consumers) {
1584 consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
1585 }
1586 ++count;
1587 }
1588 }
1589 return count;
1590 }
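
// The two consumer queries are meant to be used together (a sketch, assuming
// an existing TF_Output `out`): size the buffer, then fill it. The fill call
// returns the total count, which can exceed `max_consumers` if the graph was
// mutated between the calls.
//
//   int n = TF_OperationOutputNumConsumers(out);
//   std::vector<TF_Input> consumers(n);
//   TF_OperationOutputConsumers(out, consumers.data(), n);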
1591
1592 int TF_OperationNumControlInputs(TF_Operation* oper) {
1593 int count = 0;
1594 for (const auto* edge : oper->node.in_edges()) {
1595 if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1596 ++count;
1597 }
1598 }
1599 return count;
1600 }
1601
1602 int TF_OperationGetControlInputs(TF_Operation* oper,
1603 TF_Operation** control_inputs,
1604 int max_control_inputs) {
1605 int count = 0;
1606 for (const auto* edge : oper->node.in_edges()) {
1607 if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1608 if (count < max_control_inputs) {
1609 control_inputs[count] = ToOperation(edge->src());
1610 }
1611 ++count;
1612 }
1613 }
1614 return count;
1615 }
1616
1617 int TF_OperationNumControlOutputs(TF_Operation* oper) {
1618 int count = 0;
1619 for (const auto* edge : oper->node.out_edges()) {
1620 if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1621 ++count;
1622 }
1623 }
1624 return count;
1625 }
1626
1627 int TF_OperationGetControlOutputs(TF_Operation* oper,
1628 TF_Operation** control_outputs,
1629 int max_control_outputs) {
1630 int count = 0;
1631 for (const auto* edge : oper->node.out_edges()) {
1632 if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1633 if (count < max_control_outputs) {
1634 control_outputs[count] = ToOperation(edge->dst());
1635 }
1636 ++count;
1637 }
1638 }
1639 return count;
1640 }
1641
1642 TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1643 const char* attr_name,
1644 TF_Status* status) {
1645 TF_AttrMetadata metadata;
1646 const auto* attr = GetAttrValue(oper, attr_name, status);
1647 if (TF_GetCode(status) != TF_OK) return metadata;
1648 switch (attr->value_case()) {
1649 #define SINGLE_CASE(kK, attr_type, size_expr) \
1650 case tensorflow::AttrValue::kK: \
1651 metadata.is_list = 0; \
1652 metadata.list_size = -1; \
1653 metadata.type = attr_type; \
1654 metadata.total_size = size_expr; \
1655 break;
1656
1657 SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
1658 SINGLE_CASE(kI, TF_ATTR_INT, -1);
1659 SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
1660 SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
1661 SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
1662 SINGLE_CASE(kShape, TF_ATTR_SHAPE,
1663 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
1664 SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
1665 #undef SINGLE_CASE
1666
1667 case tensorflow::AttrValue::kList:
1668 metadata.is_list = 1;
1669 metadata.list_size = 0;
1670 metadata.total_size = -1;
1671 #define LIST_CASE(field, attr_type, ...) \
1672 if (attr->list().field##_size() > 0) { \
1673 metadata.type = attr_type; \
1674 metadata.list_size = attr->list().field##_size(); \
1675 __VA_ARGS__; \
1676 break; \
1677 }
1678
1679 LIST_CASE(
1680 s, TF_ATTR_STRING, metadata.total_size = 0;
1681 for (int i = 0; i < attr->list().s_size();
1682 ++i) { metadata.total_size += attr->list().s(i).size(); });
1683 LIST_CASE(i, TF_ATTR_INT);
1684 LIST_CASE(f, TF_ATTR_FLOAT);
1685 LIST_CASE(b, TF_ATTR_BOOL);
1686 LIST_CASE(type, TF_ATTR_TYPE);
1687 LIST_CASE(
1688 shape, TF_ATTR_SHAPE, metadata.total_size = 0;
1689 for (int i = 0; i < attr->list().shape_size(); ++i) {
1690 const auto& s = attr->list().shape(i);
1691 metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
1692 });
1693 LIST_CASE(tensor, TF_ATTR_TENSOR);
1694 LIST_CASE(func, TF_ATTR_FUNC);
1695 #undef LIST_CASE
1696 // All lists empty, determine the type from the OpDef.
1697 if (metadata.list_size == 0) {
1698 for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
1699 const auto& a = oper->node.op_def().attr(i);
1700 if (a.name().compare(attr_name) != 0) continue;
1701 const string& typestr = a.type();
1702 if (typestr == "list(string)") {
1703 metadata.type = TF_ATTR_STRING;
1704 } else if (typestr == "list(int)") {
1705 metadata.type = TF_ATTR_INT;
1706 } else if (typestr == "list(float)") {
1707 metadata.type = TF_ATTR_FLOAT;
1708 } else if (typestr == "list(bool)") {
1709 metadata.type = TF_ATTR_BOOL;
1710 } else if (typestr == "list(type)") {
1711 metadata.type = TF_ATTR_TYPE;
1712 } else if (typestr == "list(shape)") {
1713 metadata.type = TF_ATTR_SHAPE;
1714 } else if (typestr == "list(tensor)") {
1715 metadata.type = TF_ATTR_TENSOR;
1716 } else if (typestr == "list(func)") {
1717 metadata.type = TF_ATTR_FUNC;
1718 } else {
1719 status->status = InvalidArgument(
1720 "Attribute '", attr_name,
1721 "' has an empty value of an unrecognized type '", typestr, "'");
1722 return metadata;
1723 }
1724 }
1725 }
1726 break;
1727
1728 case tensorflow::AttrValue::kPlaceholder:
1729 metadata.is_list = 0;
1730 metadata.list_size = -1;
1731 metadata.type = TF_ATTR_PLACEHOLDER;
1732 metadata.total_size = -1;
1733 break;
1734
1735 case tensorflow::AttrValue::kFunc:
1736 metadata.is_list = 0;
1737 metadata.list_size = -1;
1738 metadata.type = TF_ATTR_FUNC;
1739 metadata.total_size = -1;
1740 break;
1741
1742 case tensorflow::AttrValue::VALUE_NOT_SET:
1743 status->status =
1744 InvalidArgument("Attribute '", attr_name, "' has no value set");
1745 break;
1746 }
1747 return metadata;
1748 }
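
// Interpreting the result (a sketch; the attribute name "tags" is made up):
// is_list distinguishes scalars from lists, list_size is the element count,
// and for string lists total_size is the summed byte length of all elements.
//
//   TF_AttrMetadata m = TF_OperationGetAttrMetadata(oper, "tags", status);
//   if (TF_GetCode(status) == TF_OK && m.is_list && m.type == TF_ATTR_STRING) {
//     // m.list_size strings, m.total_size bytes in total.
//   }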
1749
1750 void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
1751 void* value, size_t max_length,
1752 TF_Status* status) {
1753 const auto* attr = GetAttrValue(oper, attr_name, status);
1754 if (TF_GetCode(status) != TF_OK) return;
1755 if (attr->value_case() != tensorflow::AttrValue::kS) {
1756 status->status =
1757 InvalidArgument("Attribute '", attr_name, "' is not a string");
1758 return;
1759 }
1760 if (max_length <= 0) {
1761 return;
1762 }
1763 const auto& s = attr->s();
1764 std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length));
1765 }
1766
1767 void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
1768 void** values, size_t* lengths,
1769 int max_values, void* storage,
1770 size_t storage_size, TF_Status* status) {
1771 const auto* attr = GetAttrValue(oper, attr_name, status);
1772 if (TF_GetCode(status) != TF_OK) return;
1773 if (attr->value_case() != tensorflow::AttrValue::kList) {
1774 status->status =
1775 InvalidArgument("Value for '", attr_name, "' is not a list");
1776 return;
1777 }
1778 const auto len = std::min(max_values, attr->list().s_size());
1779 char* p = static_cast<char*>(storage);
1780 for (int i = 0; i < len; ++i) {
1781 const string& s = attr->list().s(i);
1782 values[i] = p;
1783 lengths[i] = s.size();
1784 if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
1785 status->status = InvalidArgument(
1786 "Not enough storage to hold the requested list of strings");
1787 return;
1788 }
1789 memcpy(values[i], s.data(), s.size());
1790 p += s.size();
1791 }
1792 }
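
// A retrieval sketch (hypothetical caller code) that sizes every buffer from
// the attribute metadata before fetching the list:
//
//   TF_AttrMetadata m = TF_OperationGetAttrMetadata(oper, "tags", status);
//   std::vector<void*> values(m.list_size);
//   std::vector<size_t> lengths(m.list_size);
//   std::vector<char> storage(m.total_size);
//   TF_OperationGetAttrStringList(oper, "tags", values.data(), lengths.data(),
//                                 static_cast<int>(m.list_size), storage.data(),
//                                 storage.size(), status);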
1793
1794 #define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \
1795 void func(TF_Operation* oper, const char* attr_name, c_type* value, \
1796 TF_Status* status) { \
1797 cpp_type v; \
1798 status->status = \
1799 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \
1800 *value = static_cast<c_type>(v); \
1801 } \
1802 void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
1803 int max_values, TF_Status* status) { \
1804 const auto* attr = GetAttrValue(oper, attr_name, status); \
1805 if (TF_GetCode(status) != TF_OK) return; \
1806 if (attr->value_case() != tensorflow::AttrValue::kList) { \
1807 status->status = \
1808 InvalidArgument("Value for '", attr_name, "' is not a list."); \
1809 return; \
1810 } \
1811 const auto len = std::min(max_values, attr->list().list_field##_size()); \
1812 for (int i = 0; i < len; ++i) { \
1813 values[i] = static_cast<c_type>(attr->list().list_field(i)); \
1814 } \
1815 }
1816 DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
1817 DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
1818 DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
1819 DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
1820 #undef DEFINE_GETATTR
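
// For example, the first instantiation above expands to
// TF_OperationGetAttrInt(), which reads a single int64 attribute, and
// TF_OperationGetAttrIntList(), which copies at most `max_values` elements
// out of the attribute's list().i field.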
1821
1822 void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
1823 int64_t* value, int num_dims, TF_Status* status) {
1824 PartialTensorShape shape;
1825 status->status =
1826 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
1827 if (TF_GetCode(status) != TF_OK) return;
1828 auto len = std::min(shape.dims(), num_dims);
1829 for (int i = 0; i < len; ++i) {
1830 value[i] = shape.dim_size(i);
1831 }
1832 }
1833
1834 void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
1835 int64_t** values, int* num_dims,
1836 int max_values, int64_t* storage,
1837 int storage_size, TF_Status* status) {
1838 std::vector<PartialTensorShape> shapes;
1839 status->status =
1840 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
1841 if (TF_GetCode(status) != TF_OK) return;
1842 auto len = std::min(static_cast<int>(shapes.size()), max_values);
1843 int64_t* p = storage;
1844 int storage_left = storage_size;
1845 for (int i = 0; i < len; ++i) {
1846 // shapes[i].dims() == -1 for shapes with an unknown rank.
1847 int64_t n = shapes[i].dims();
1848 num_dims[i] = n;
1849 values[i] = p;
1850 if (n < 0) {
1851 continue;
1852 }
1853 if (storage_left < n) {
1854 status->status = InvalidArgument(
1855 "Not enough storage to hold the requested list of shapes");
1856 return;
1857 }
1858 storage_left -= n;
1859 for (int j = 0; j < n; ++j, ++p) {
1860 *p = shapes[i].dim_size(j);
1861 }
1862 }
1863 }
1864
1865 void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
1866 const char* attr_name,
1867 TF_Buffer* value, TF_Status* status) {
1868 const auto* attr = GetAttrValue(oper, attr_name, status);
1869 if (TF_GetCode(status) != TF_OK) return;
1870 if (attr->value_case() != tensorflow::AttrValue::kShape) {
1871 status->status =
1872 InvalidArgument("Value for '", attr_name, "' is not a shape.");
1873 return;
1874 }
1875 status->status = MessageToBuffer(attr->shape(), value);
1876 }
1877
1878 void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
1879 const char* attr_name,
1880 TF_Buffer** values, int max_values,
1881 TF_Status* status) {
1882 const auto* attr = GetAttrValue(oper, attr_name, status);
1883 if (TF_GetCode(status) != TF_OK) return;
1884 if (attr->value_case() != tensorflow::AttrValue::kList) {
1885 status->status =
1886 InvalidArgument("Value for '", attr_name, "' is not a list");
1887 return;
1888 }
1889 const auto len = std::min(max_values, attr->list().shape_size());
1890 for (int i = 0; i < len; ++i) {
1891 values[i] = TF_NewBuffer();
1892 status->status = MessageToBuffer(attr->list().shape(i), values[i]);
1893 if (TF_GetCode(status) != TF_OK) {
1894 // Delete everything allocated so far; the operation has failed.
1895 for (int j = 0; j <= i; ++j) {
1896 TF_DeleteBuffer(values[j]);
1897 }
1898 return;
1899 }
1900 }
1901 }
1902
1903 void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
1904 TF_Tensor** value, TF_Status* status) {
1905 *value = nullptr;
1906 Tensor t;
1907 status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
1908 if (TF_GetCode(status) != TF_OK) return;
1909 *value = TF_TensorFromTensor(t, status);
1910 }
1911
1912 void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
1913 TF_Tensor** values, int max_values,
1914 TF_Status* status) {
1915 std::vector<Tensor> ts;
1916 status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
1917 if (TF_GetCode(status) != TF_OK) return;
1918 const auto len = std::min(max_values, static_cast<int>(ts.size()));
1919 for (int i = 0; i < len; ++i) {
1920 values[i] = TF_TensorFromTensor(ts[i], status);
1921 }
1922 }
1923
1924 void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
1925 TF_Buffer* output_attr_value,
1926 TF_Status* status) {
1927 const auto* attr = GetAttrValue(oper, attr_name, status);
1928 if (TF_GetCode(status) != TF_OK) return;
1929 status->status = MessageToBuffer(*attr, output_attr_value);
1930 }
1931
1932 void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
1933 TF_Status* status) {
1934 status->status = MessageToBuffer(oper->node.def(), output_node_def);
1935 }
1936
1937 // TF_Graph functions ---------------------------------------------------------
1938
1939 TF_Graph::TF_Graph()
1940 : graph(tensorflow::OpRegistry::Global()),
1941 refiner(graph.versions().producer(), graph.op_registry()),
1942 delete_requested(false),
1943 parent(nullptr),
1944 parent_inputs(nullptr) {}
1945
1946 TF_Graph* TF_NewGraph() { return new TF_Graph; }
1947
1948 void TF_DeleteGraph(TF_Graph* g) {
1949 if (g == nullptr) return;
1950 g->mu.lock();
1951 g->delete_requested = true;
1952 const bool del = g->sessions.empty();
1953 g->mu.unlock();
1954 if (del) delete g;
1955 }
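
// Lifetime note: a TF_Graph is shared with the sessions created against it,
// so deletion is deferred until the last session goes away. This ordering is
// therefore safe (a sketch):
//
//   TF_Session* sess = TF_NewSession(graph, opts, status);
//   TF_DeleteGraph(graph);           // only marks delete_requested
//   TF_DeleteSession(sess, status);  // last session gone -> graph freed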
1956
1957 TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
1958 mutex_lock l(graph->mu);
1959 auto iter = graph->name_map.find(oper_name);
1960 if (iter == graph->name_map.end()) {
1961 return nullptr;
1962 } else {
1963 return ToOperation(iter->second);
1964 }
1965 }
1966
1967 TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
1968 if (*pos == 0) {
1969 // Advance past the first sentinel nodes in every graph (the source & sink).
1970 *pos += 2;
1971 } else {
1972 // Advance to the next node.
1973 *pos += 1;
1974 }
1975
1976 mutex_lock l(graph->mu);
1977 while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) {
1978 Node* node = graph->graph.FindNodeId(*pos);
1979 // FindNodeId() returns nullptr for nodes that have been deleted.
1980 // We aren't currently allowing nodes to be deleted, but it is safer
1981 // to still check.
1982 if (node != nullptr) return ToOperation(node);
1983 *pos += 1;
1984 }
1985
1986 // No more nodes.
1987 return nullptr;
1988 }
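
// Iteration sketch (hypothetical caller code): `pos` is an opaque cursor that
// must start at zero; the source and sink sentinels are skipped internally.
//
//   size_t pos = 0;
//   TF_Operation* op;
//   while ((op = TF_GraphNextOperation(graph, &pos)) != nullptr) {
//     printf("%s\n", TF_OperationName(op));
//   }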
1989
1990 void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
1991 TF_Status* status) {
1992 GraphDef def;
1993 {
1994 mutex_lock l(graph->mu);
1995 graph->graph.ToGraphDef(&def);
1996 }
1997 status->status = MessageToBuffer(def, output_graph_def);
1998 }
1999
2000 void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
2001 TF_Buffer* output_op_def, TF_Status* status) {
2002 const OpDef* op_def;
2003 {
2004 mutex_lock l(graph->mu);
2005 status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
2006 if (TF_GetCode(status) != TF_OK) return;
2007 }
2008 status->status = MessageToBuffer(*op_def, output_op_def);
2009 }
2010
2011 void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
2012 TF_Status* status) {
2013 VersionDef versions;
2014 {
2015 mutex_lock l(graph->mu);
2016 versions = graph->graph.versions();
2017 }
2018 status->status = MessageToBuffer(versions, output_version_def);
2019 }
2020
2021 TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
2022 return new TF_ImportGraphDefOptions;
2023 }
2024 void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
2025 delete opts;
2026 }
2027 void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
2028 const char* prefix) {
2029 opts->opts.prefix = prefix;
2030 }
2031 void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts,
2032 const char* device) {
2033 opts->opts.default_device = device;
2034 }
2035
2036 void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
2037 unsigned char uniquify_names) {
2038 opts->opts.uniquify_names = uniquify_names;
2039 }
2040
2041 void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
2042 unsigned char uniquify_prefix) {
2043 opts->opts.uniquify_prefix = uniquify_prefix;
2044 }
2045
2046 void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
2047 const char* src_name,
2048 int src_index, TF_Output dst) {
2049 opts->tensor_id_data.push_back(src_name);
2050 const string& src_name_str = opts->tensor_id_data.back();
2051 // We don't need to store dst's name in tensor_id_data, since `dst` must
2052 // outlive the ImportGraphDef call.
2053 opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
2054 }
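
// A mapping sketch (hypothetical node names): redirect the to-be-imported
// graph's "input:0" to a tensor that already exists in the destination graph.
//
//   TF_Output existing = {TF_GraphOperationByName(graph, "x"), 0};
//   TF_ImportGraphDefOptionsAddInputMapping(opts, "input", 0, existing);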
2055
2056 void TF_ImportGraphDefOptionsRemapControlDependency(
2057 TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
2058 opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
2059 TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
2060 }
2061
2062 extern void TF_ImportGraphDefOptionsAddControlDependency(
2063 TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
2064 opts->opts.control_dependencies.push_back(oper->node.name());
2065 }
2066
2067 void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
2068 const char* oper_name, int index) {
2069 opts->tensor_id_data.push_back(oper_name);
2070 const string& oper_name_str = opts->tensor_id_data.back();
2071 opts->opts.return_tensors.emplace_back(oper_name_str, index);
2072 }
2073
2074 int TF_ImportGraphDefOptionsNumReturnOutputs(
2075 const TF_ImportGraphDefOptions* opts) {
2076 return opts->opts.return_tensors.size();
2077 }
2078
2079 void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
2080 const char* oper_name) {
2081 opts->opts.return_nodes.push_back(oper_name);
2082 }
2083
2084 int TF_ImportGraphDefOptionsNumReturnOperations(
2085 const TF_ImportGraphDefOptions* opts) {
2086 return opts->opts.return_nodes.size();
2087 }
2088
2089 void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
2090 int* num_outputs,
2091 TF_Output** outputs) {
2092 *num_outputs = results->return_tensors.size();
2093 *outputs = results->return_tensors.data();
2094 }
2095
2096 void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
2097 int* num_opers,
2098 TF_Operation*** opers) {
2099 *num_opers = results->return_nodes.size();
2100 *opers = results->return_nodes.data();
2101 }
2102
2103 void TF_ImportGraphDefResultsMissingUnusedInputMappings(
2104 TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
2105 const char*** src_names, int** src_indexes) {
2106 *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
2107 *src_names = results->missing_unused_key_names.data();
2108 *src_indexes = results->missing_unused_key_indexes.data();
2109 }
2110
2111 void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
2112 delete results;
2113 }
2114
2115 static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
2116 const TF_ImportGraphDefOptions* opts,
2117 TF_ImportGraphDefResults* tf_results,
2118 TF_Status* status)
2119 EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
2120 const int last_node_id = graph->graph.num_node_ids();
2121 tensorflow::ImportGraphDefResults results;
2122 status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
2123 &graph->refiner, &results);
2124 if (TF_GetCode(status) != TF_OK) return;
2125
2126 // Add new nodes to name_map
2127 for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
2128 auto* node = graph->graph.FindNodeId(i);
2129 if (node != nullptr) graph->name_map[node->name()] = node;
2130 }
2131
2132 // Populate return_tensors
2133 DCHECK(tf_results->return_tensors.empty());
2134 tf_results->return_tensors.resize(results.return_tensors.size());
2135 for (int i = 0; i < results.return_tensors.size(); ++i) {
2136 tf_results->return_tensors[i].oper =
2137 ToOperation(results.return_tensors[i].first);
2138 tf_results->return_tensors[i].index = results.return_tensors[i].second;
2139 }
2140
2141 // Populate return_nodes
2142 DCHECK(tf_results->return_nodes.empty());
2143 tf_results->return_nodes.resize(results.return_nodes.size());
2144 for (int i = 0; i < results.return_nodes.size(); ++i) {
2145 tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
2146 }
2147
2148 // Populate missing unused map keys
2149 DCHECK(tf_results->missing_unused_key_names.empty());
2150 DCHECK(tf_results->missing_unused_key_indexes.empty());
2151 DCHECK(tf_results->missing_unused_key_names_data.empty());
2152
2153 size_t size = results.missing_unused_input_map_keys.size();
2154 tf_results->missing_unused_key_names.resize(size);
2155 tf_results->missing_unused_key_indexes.resize(size);
2156
2157 for (int i = 0; i < size; ++i) {
2158 TensorId id = results.missing_unused_input_map_keys[i];
2159 tf_results->missing_unused_key_names_data.emplace_back(id.first);
2160 tf_results->missing_unused_key_names[i] =
2161 tf_results->missing_unused_key_names_data.back().c_str();
2162 tf_results->missing_unused_key_indexes[i] = id.second;
2163 }
2164 }
2165
2166 TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
2167 TF_Graph* graph, const TF_Buffer* graph_def,
2168 const TF_ImportGraphDefOptions* options, TF_Status* status) {
2169 GraphDef def;
2170 if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
2171 graph_def->length)) {
2172 status->status = InvalidArgument("Invalid GraphDef");
2173 return nullptr;
2174 }
2175 auto results = new TF_ImportGraphDefResults();
2176 mutex_lock l(graph->mu);
2177 GraphImportGraphDefLocked(graph, def, options, results, status);
2178 if (TF_GetCode(status) != TF_OK) {
2179 delete results;
2180 return nullptr;
2181 }
2182 return results;
2183 }
2184
2185 void TF_GraphImportGraphDefWithReturnOutputs(
2186 TF_Graph* graph, const TF_Buffer* graph_def,
2187 const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
2188 int num_return_outputs, TF_Status* status) {
2189 if (num_return_outputs != options->opts.return_tensors.size()) {
2190 status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
2191 options->opts.return_tensors.size(),
2192 ", got ", num_return_outputs);
2193 return;
2194 }
2195 if (num_return_outputs > 0 && return_outputs == nullptr) {
2196 status->status = InvalidArgument(
2197 "'return_outputs' must be preallocated to length ", num_return_outputs);
2198 return;
2199 }
2200 GraphDef def;
2201 if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
2202 graph_def->length)) {
2203 status->status = InvalidArgument("Invalid GraphDef");
2204 return;
2205 }
2206 TF_ImportGraphDefResults results;
2207 mutex_lock l(graph->mu);
2208 GraphImportGraphDefLocked(graph, def, options, &results, status);
if (TF_GetCode(status) != TF_OK) return;
2209 DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
2210 memcpy(return_outputs, results.return_tensors.data(),
2211 num_return_outputs * sizeof(TF_Output));
2212 }
2213
2214 void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
2215 const TF_ImportGraphDefOptions* options,
2216 TF_Status* status) {
2217 TF_ImportGraphDefResults* results =
2218 TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
2219 TF_DeleteImportGraphDefResults(results);
2220 }
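
// A round-trip sketch (hypothetical caller code): serialize one graph and
// import it into another under a prefix, so every imported node is named
// "imported/...".
//
//   TF_Buffer* def = TF_NewBuffer();
//   TF_GraphToGraphDef(src_graph, def, status);
//   TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
//   TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
//   TF_GraphImportGraphDef(dst_graph, def, opts, status);
//   TF_DeleteImportGraphDefOptions(opts);
//   TF_DeleteBuffer(def);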
2221
2222 // While loop functions -------------------------------------------------------
2223
2224 namespace {
2225
2226 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2227
2228 // Creates a placeholder representing an input to the cond or body graph.
2229 // TODO(skyewm): remove these from final graph
2230 bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
2231 TF_Output* input, TF_Status* status) {
2232 TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
2233 TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
2234 // TODO(skyewm): set placeholder shape
2235 TF_Operation* oper = TF_FinishOperation(desc, status);
2236 if (TF_GetCode(status) != TF_OK) return false;
2237 *input = {oper, 0};
2238 return true;
2239 }
2240
2241 // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
2242 // `src_inputs[i]` will have that input replaced with `dst_inputs[i]`. `prefix`
2243 // will be prepended to copied node names. `control_deps` are nodes in
2244 // `dst_graph` that the copied `src_graph` nodes will have control dependencies
2245 // on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
2246 // in `dst_graph` will be returned. `return_nodes` must be non-null.
2247 Status CopyGraph(Graph* src_graph, Graph* dst_graph,
2248 tensorflow::ShapeRefiner* dst_refiner,
2249 const TF_Output* src_inputs,
2250 const std::vector<tensorflow::Output>& dst_inputs,
2251 const string& prefix,
2252 const std::vector<tensorflow::Operation>& control_deps,
2253 const TF_Output* nodes_to_return, int nreturn_nodes,
2254 std::vector<tensorflow::Output>* return_nodes) {
2255 DCHECK(return_nodes != nullptr);
2256 GraphDef gdef;
2257 src_graph->ToGraphDef(&gdef);
2258
2259 tensorflow::ImportGraphDefOptions opts;
2260 opts.prefix = prefix;
2261
2262 for (int i = 0; i < dst_inputs.size(); ++i) {
2263 opts.input_map[ToTensorId(src_inputs[i])] =
2264 TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
2265 }
2266 opts.skip_mapped_nodes = true;
2267
2268 for (const tensorflow::Operation& op : control_deps) {
2269 opts.control_dependencies.push_back(op.node()->name());
2270 }
2271
2272 for (int i = 0; i < nreturn_nodes; ++i) {
2273 opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
2274 }
2275
2276 // TODO(skyewm): change to OutputTensor
2277 tensorflow::ImportGraphDefResults results;
2278 TF_RETURN_IF_ERROR(
2279 ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
2280
2281 for (const auto& pair : results.return_tensors) {
2282 return_nodes->emplace_back(pair.first, pair.second);
2283 }
2284 return Status::OK();
2285 }
2286
2287 bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
2288 if (params.cond_graph == nullptr || params.body_graph == nullptr ||
2289 params.cond_graph->parent == nullptr ||
2290 params.cond_graph->parent != params.body_graph->parent ||
2291 params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
2292 params.ninputs <= 0 || params.cond_inputs == nullptr ||
2293 params.body_inputs == nullptr || params.body_outputs == nullptr) {
2294 s->status = InvalidArgument(
2295 "TF_WhileParams must be created by successful TF_NewWhile() call");
2296 return false;
2297 }
2298 return true;
2299 }
2300
2301 bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
2302 if (params.cond_output.oper == nullptr) {
2303 s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
2304 return false;
2305 }
2306 for (int i = 0; i < params.ninputs; ++i) {
2307 if (params.body_outputs[i].oper == nullptr) {
2308 s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
2309 "field isn't set");
2310 return false;
2311 }
2312 }
2313 if (params.name == nullptr) {
2314 s->status = InvalidArgument("TF_WhileParams `name` field is null");
2315 return false;
2316 }
2317 return true;
2318 }
2319
2320 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2321
2322 void FreeWhileResources(const TF_WhileParams* params) {
2323 TF_DeleteGraph(params->cond_graph);
2324 TF_DeleteGraph(params->body_graph);
2325 delete[] params->cond_inputs;
2326 delete[] params->body_inputs;
2327 delete[] params->body_outputs;
2328 }
2329
2330 TF_WhileParams EmptyWhileParams() {
2331 return {0, nullptr, nullptr, {nullptr, 0},
2332 nullptr, nullptr, nullptr, nullptr};
2333 }
2334
2335 } // namespace
2336
2337 TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
2338 TF_Status* status) {
2339 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2340 status->status = tensorflow::errors::Unimplemented(
2341 "Creating while loops is not supported on mobile. File a bug at "
2342 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2343 "important to you");
2344 return EmptyWhileParams();
2345 #else
2346 if (ninputs == 0) {
2347 status->status =
2348 InvalidArgument("TF_NewWhile() must be passed at least one input");
2349 return EmptyWhileParams();
2350 }
2351
2352 TF_Graph* cond_graph = TF_NewGraph();
2353 TF_Graph* body_graph = TF_NewGraph();
2354 cond_graph->parent = g;
2355 cond_graph->parent_inputs = inputs;
2356 body_graph->parent = g;
2357 body_graph->parent_inputs = inputs;
2358
2359 TF_Output* cond_inputs = new TF_Output[ninputs];
2360 TF_Output cond_output = {nullptr, -1};
2361 TF_Output* body_inputs = new TF_Output[ninputs];
2362 TF_Output* body_outputs = new TF_Output[ninputs];
2363 for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
2364 const char* name = nullptr;
2365
2366 for (int i = 0; i < ninputs; ++i) {
2367 // TODO(skyewm): prefix names with underscore (requires some plumbing)
2368 if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
2369 &cond_inputs[i], status)) {
2370 break;
2371 }
2372 if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
2373 &body_inputs[i], status)) {
2374 break;
2375 }
2376 }
2377
2378 TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output,
2379 body_graph, body_inputs, body_outputs, name};
2380
2381 if (TF_GetCode(status) != TF_OK) {
2382 FreeWhileResources(&params);
2383 return EmptyWhileParams();
2384 }
2385 return params;
2386 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2387 }
2388
2389 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2390 namespace {
2391
2392 // TODO(skyewm): make nodes in while loop unfetchable like in Python version
2393 void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
2394 TF_Output* outputs) {
2395 if (!ValidateInputWhileParams(*params, status)) return;
2396
2397 TF_Graph* parent = params->cond_graph->parent;
2398 TF_Output* parent_inputs = params->cond_graph->parent_inputs;
2399 int num_loop_vars = params->ninputs;
2400
2401 mutex_lock l(parent->mu);
2402
2403 // 'cond_fn' copies the cond graph into the parent graph.
2404 tensorflow::ops::CondGraphBuilderFn cond_fn =
2405 [params, parent](const tensorflow::Scope& scope,
2406 const std::vector<tensorflow::Output>& inputs,
2407 tensorflow::Output* output) {
2408 DCHECK_EQ(scope.graph(), &parent->graph);
2409 std::vector<tensorflow::Output> cond_output;
2410 TF_RETURN_IF_ERROR(CopyGraph(
2411 &params->cond_graph->graph, &parent->graph, &parent->refiner,
2412 params->cond_inputs, inputs, scope.impl()->name(),
2413 scope.impl()->control_deps(), &params->cond_output,
2414 /* nreturn_nodes */ 1, &cond_output));
2415 *output = cond_output[0];
2416 return Status::OK();
2417 };
2418
2419 // 'body_fn' copies the body graph into the parent graph.
2420 tensorflow::ops::BodyGraphBuilderFn body_fn =
2421 [params, parent, num_loop_vars](
2422 const tensorflow::Scope& scope,
2423 const std::vector<tensorflow::Output>& inputs,
2424 std::vector<tensorflow::Output>* outputs) {
2425 DCHECK_EQ(scope.graph(), &parent->graph);
2426 TF_RETURN_IF_ERROR(
2427 CopyGraph(&params->body_graph->graph, &parent->graph,
2428 &parent->refiner, params->body_inputs, inputs,
2429 scope.impl()->name(), scope.impl()->control_deps(),
2430 params->body_outputs, num_loop_vars, outputs));
2431 return Status::OK();
2432 };
2433
2434 // Create the while loop using an internal scope.
2435 tensorflow::Scope scope =
2436 NewInternalScope(&parent->graph, &status->status, &parent->refiner)
2437 .NewSubScope(params->name);
2438
2439 const int first_new_node_id = parent->graph.num_node_ids();
2440
2441 tensorflow::OutputList loop_outputs;
2442 status->status = tensorflow::ops::BuildWhileLoop(
2443 scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
2444 body_fn, params->name, &loop_outputs);
2445
2446 // Update name_map with newly-created ops.
2447 // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
2448 // a bad status. Once we fix this, we may want to return early instead of
2449 // executing the following code.
2450 for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
2451 Node* new_node = parent->graph.FindNodeId(i);
2452 if (new_node == nullptr) continue;
2453 parent->name_map[new_node->name()] = new_node;
2454 }
2455
2456 // Populate 'outputs'.
2457 DCHECK_LE(loop_outputs.size(), num_loop_vars);
2458 for (int i = 0; i < loop_outputs.size(); ++i) {
2459 outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
2460 }
2461 }
2462
2463 } // namespace
2464 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2465
2466 void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
2467 TF_Output* outputs) {
2468 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2469 status->status = tensorflow::errors::Unimplemented(
2470 "Creating while loops is not supported on mobile. File a bug at "
2471 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2472 "important to you");
2473 #else
2474 // If it appears the caller created or modified `params`, don't free resources
2475 if (!ValidateConstWhileParams(*params, status)) return;
2476 TF_FinishWhileHelper(params, status, outputs);
2477 FreeWhileResources(params);
2478 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2479 }
2480
2481 void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
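
// End-to-end while-loop sketch (hypothetical caller code; `g`, `inputs`, and
// `status` are assumed to exist). The caller populates the cond/body graphs
// between TF_NewWhile and TF_FinishWhile, and must call TF_AbortWhile instead
// if it gives up before finishing.
//
//   TF_WhileParams p = TF_NewWhile(g, inputs, 1, status);
//   p.name = "loop";
//   // ... build p.cond_output in p.cond_graph from p.cond_inputs ...
//   // ... build p.body_outputs in p.body_graph from p.body_inputs ...
//   TF_Output outputs[1];
//   TF_FinishWhile(&p, status, outputs);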
2482
2483 void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
2484 TF_Output* dx, TF_Status* status, TF_Output* dy) {
2485 TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy);
2486 }
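
// A gradient sketch (hypothetical caller code): dy[i] receives the symbolic
// gradient of y with respect to x[i]; with dx == nullptr the gradients are
// seeded with ones.
//
//   TF_Output y[1] = {loss}, x[1] = {weights}, dy[1];
//   TF_AddGradients(graph, y, 1, x, 1, /*dx=*/nullptr, status, dy);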
2487
2488 void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
2489 int ny, TF_Output* x, int nx, TF_Output* dx,
2490 TF_Status* status, TF_Output* dy) {
2491 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2492 status->status = tensorflow::errors::Unimplemented(
2493 "Adding gradients is not supported on mobile. File a bug at "
2494 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2495 "important to you");
2496 #else
2497 std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
2498 std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
2499 std::vector<tensorflow::Output> dy_arg;
2500
2501 {
2502 // We need to hold on to the lock while we have a scope that uses TF_Graph.
2503 mutex_lock graph_lock(g->mu);
2504
2505 const int first_new_node_id = g->graph.num_node_ids();
2506
2507 string prefix_cmp;
2508 const char* child_scope_name;
2509 if (prefix == nullptr) {
2510 child_scope_name = "gradients";
2511 } else {
2512 prefix_cmp = string(prefix) + "/";
2513 // The operation should fail if the provided name prefix has already been
2514 // used in this graph
2515 for (const auto& pair : g->name_map) {
2516 const string& name = pair.first;
2517 if (name.compare(prefix) == 0 ||
2518 tensorflow::str_util::StartsWith(name, prefix_cmp)) {
2519 status->status = InvalidArgument(
2520 "prefix [", prefix,
2521 "] conflicts with existing node in the graph named [", name, "]");
2522 return;
2523 }
2524 }
2525 child_scope_name = prefix;
2526 }
2527 tensorflow::Scope scope =
2528 NewInternalScope(&g->graph, &status->status, &g->refiner)
2529 .NewSubScope(child_scope_name);
2530
2531 if (dx != nullptr) {
2532 std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
2533 status->status =
2534 AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
2535 } else {
2536 status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
2537 }
2538
2539 // Update g->name_map with the name_map from the scope, which will contain
2540 // the new gradient ops.
2541 for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
2542 Node* n = g->graph.FindNodeId(i);
2543 if (n == nullptr) continue;
2544
2545 // Adding the gradients to the graph can alter the prefix to prevent
2546 // name collisions only if this prefix has not been provided explicitly
2547 // by the user. If it was provided, assert that it remained intact.
2548 if (prefix != nullptr &&
2549 !tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) {
2550 status->status = tensorflow::errors::Internal(
2551 "BUG: The gradients prefix has been unexpectedly altered when "
2552 "adding the nodes to the graph. This is a bug. Please file an "
2553 "issue at https://github.com/tensorflow/tensorflow/issues.");
2554 return;
2555 }
2556 // We have a convoluted scheme here: the C++ graph construction API adds
2557 // potentially many nodes to the graph without running the checks (such as
2558 // uniqueness of node names) that other functions which add a single node
2559 // to the graph (like TF_FinishOperation) do run.
2560 if (!g->name_map.insert(std::make_pair(n->name(), n)).second) {
2561 status->status = tensorflow::errors::Internal(
2562 "BUG: The API allowed construction of a graph with duplicate node "
2563 "names (",
2564 n->name(),
2565 "). This is a bug. Please file an issue at "
2566 "https://github.com/tensorflow/tensorflow/issues.");
2567 }
2568 }
2569 }
2570
2571 // Unpack the results from grad_outputs_arg.
2572 TFOutputsFromOutputs(dy_arg, dy);
2573 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2574 }
2575
2576 // TF_Session functions ----------------------------------------------
2577
2578 TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
2579 : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {}
2580
2581 TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
2582 TF_Status* status) {
2583 Session* session;
2584 status->status = NewSession(opt->options, &session);
2585 if (TF_GetCode(status) == TF_OK) {
2586 TF_Session* new_session = new TF_Session(session, graph);
2587 if (graph != nullptr) {
2588 mutex_lock l(graph->mu);
2589 graph->sessions[new_session] = "";
2590 }
2591 return new_session;
2592 } else {
2593 DCHECK_EQ(nullptr, session);
2594 return nullptr;
2595 }
2596 }
2597
2598 TF_Session* TF_LoadSessionFromSavedModel(
2599 const TF_SessionOptions* session_options, const TF_Buffer* run_options,
2600 const char* export_dir, const char* const* tags, int tags_len,
2601 TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
2602 // TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring
2603 // that the tensorflow/cc/saved_model:loader build target is mobile friendly.
2604 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2605 status->status = tensorflow::errors::Unimplemented(
2606 "Loading a SavedModel is not supported on mobile. File a bug at "
2607 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2608 "important to you");
2609 return nullptr;
2610 #else
2611 mutex_lock l(graph->mu);
2612 if (!graph->name_map.empty()) {
2613 status->status = InvalidArgument("Graph is non-empty.");
2614 return nullptr;
2615 }
2616
2617 RunOptions run_options_proto;
2618 if (run_options != nullptr && !run_options_proto.ParseFromArray(
2619 run_options->data, run_options->length)) {
2620 status->status = InvalidArgument("Unparseable RunOptions proto");
2621 return nullptr;
2622 }
2623
2624 std::unordered_set<string> tag_set;
2625 for (int i = 0; i < tags_len; i++) {
2626 tag_set.insert(string(tags[i]));
2627 }
2628
2629 tensorflow::SavedModelBundle bundle;
2630 status->status =
2631 tensorflow::LoadSavedModel(session_options->options, run_options_proto,
2632 export_dir, tag_set, &bundle);
2633 if (TF_GetCode(status) != TF_OK) return nullptr;
2634
2635 // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
2636 // extends using GraphDefs. The Graph instance is different, but equivalent
2637 // to the one used to create the session.
2638 //
2639 // TODO(jhseu): When Session is modified to take Graphs instead of
2640 // GraphDefs, return the Graph generated in LoadSavedModel().
2641 TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
2642 TF_ImportGraphDefResults results;
2643 GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
2644 import_opts, &results, status);
2645 TF_DeleteImportGraphDefOptions(import_opts);
2646 if (TF_GetCode(status) != TF_OK) return nullptr;
2647
2648 if (meta_graph_def != nullptr) {
2649 status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
2650 if (TF_GetCode(status) != TF_OK) return nullptr;
2651 }
2652
2653 TF_Session* session = new TF_Session(bundle.session.release(), graph);
2654
2655 graph->sessions[session] = "";
2656 session->last_num_graph_nodes = graph->graph.num_node_ids();
2657 return session;
2658 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2659 }
2660
2661 void TF_CloseSession(TF_Session* s, TF_Status* status) {
2662 status->status = s->session->Close();
2663 }
2664
2665 void TF_DeleteSession(TF_Session* s, TF_Status* status) {
2666 status->status = Status::OK();
2667 if (s == nullptr) return;
2668 TF_Graph* const graph = s->graph;
2669 if (graph != nullptr) {
2670 graph->mu.lock();
2671 graph->sessions.erase(s);
2672 const bool del = graph->delete_requested && graph->sessions.empty();
2673 graph->mu.unlock();
2674 if (del) delete graph;
2675 }
2676 delete s->session;
2677 delete s;
2678 }
2679
2680 void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
2681 const TF_Output* inputs, TF_Tensor* const* input_values,
2682 int ninputs, const TF_Output* outputs,
2683 TF_Tensor** output_values, int noutputs,
2684 const TF_Operation* const* target_opers, int ntargets,
2685 TF_Buffer* run_metadata, TF_Status* status) {
2686 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2687 // directly, instead of requiring us to serialize to a GraphDef and
2688 // call Session::Extend().
2689 if (session->extend_before_run &&
2690 !ExtendSessionGraphHelper(session, status)) {
2691 return;
2692 }
2693
2694 TF_Run_Setup(noutputs, output_values, status);
2695
2696 // Convert from TF_Output and TF_Tensor to a string and Tensor.
2697 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2698 if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2699 for (int i = 0; i < ninputs; ++i) {
2700 input_pairs[i].first = OutputName(inputs[i]);
2701 }
2702
2703 // Convert from TF_Output to string names.
2704 std::vector<string> output_names(noutputs);
2705 for (int i = 0; i < noutputs; ++i) {
2706 output_names[i] = OutputName(outputs[i]);
2707 }
2708
2709 // Convert from TF_Operation* to string names.
2710 std::vector<string> target_names(ntargets);
2711 for (int i = 0; i < ntargets; ++i) {
2712 target_names[i] = target_opers[i]->node.name();
2713 }
2714
2715 // Actually run.
2716 TF_Run_Helper(session->session, nullptr, run_options, input_pairs,
2717 output_names, output_values, target_names, run_metadata,
2718 status);
2719 }
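
// Illustrative call (sketch; the operation names "x" and "y" and the
// previously created `in_tensor` are assumptions):
//
//   TF_Output feed = {TF_GraphOperationByName(graph, "x"), 0};
//   TF_Output fetch = {TF_GraphOperationByName(graph, "y"), 0};
//   TF_Tensor* out = nullptr;
//   TF_SessionRun(session, /*run_options=*/nullptr,
//                 &feed, &in_tensor, 1,   // feeds
//                 &fetch, &out, 1,        // fetches
//                 nullptr, 0,             // targets
//                 /*run_metadata=*/nullptr, status);
//   // On success, the caller owns `out` and must free it with
//   // TF_DeleteTensor.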

void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
                         int ninputs, const TF_Output* outputs, int noutputs,
                         const TF_Operation* const* target_opers, int ntargets,
                         const char** handle, TF_Status* status) {
  *handle = nullptr;

  if (session->extend_before_run &&
      !ExtendSessionGraphHelper(session, status)) {
    return;
  }

  std::vector<string> input_names(ninputs);
  for (int i = 0; i < ninputs; ++i) {
    input_names[i] = OutputName(inputs[i]);
  }

  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = OutputName(outputs[i]);
  }

  std::vector<string> target_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_names[i] = target_opers[i]->node.name();
  }

  string new_handle;
  status->status = session->session->PRunSetup(input_names, output_names,
                                               target_names, &new_handle);
  if (TF_GetCode(status) == TF_OK) {
    char* buf = new char[new_handle.size() + 1];
    memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
    *handle = buf;
  }
}

void TF_DeletePRunHandle(const char* handle) {
  delete[] handle;
  // TODO(suharshs): Free up any resources held by the partial run state.
}

void TF_SessionPRun(TF_Session* session, const char* handle,
                    const TF_Output* inputs, TF_Tensor* const* input_values,
                    int ninputs, const TF_Output* outputs,
                    TF_Tensor** output_values, int noutputs,
                    const TF_Operation* const* target_opers, int ntargets,
                    TF_Status* status) {
  // TODO(josh11b,mrry): Change Session to be able to use a Graph*
  // directly, instead of requiring us to serialize to a GraphDef and
  // call Session::Extend().
  if (session->extend_before_run &&
      !ExtendSessionGraphHelper(session, status)) {
    return;
  }

  TF_Run_Setup(noutputs, output_values, status);

  // Convert from TF_Output and TF_Tensor to a string and Tensor.
  std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
  if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
  for (int i = 0; i < ninputs; ++i) {
    input_pairs[i].first = OutputName(inputs[i]);
  }

  // Convert from TF_Output to string names.
  std::vector<string> output_names(noutputs);
  for (int i = 0; i < noutputs; ++i) {
    output_names[i] = OutputName(outputs[i]);
  }

  // Convert from TF_Operation* to string names.
  std::vector<string> target_names(ntargets);
  for (int i = 0; i < ntargets; ++i) {
    target_names[i] = target_opers[i]->node.name();
  }

  TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
                output_values, target_names, nullptr, status);
}
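
// Partial-run flow (sketch; `feeds`, `fetches`, `feed_val`, and `fetch_val`
// are caller-provided placeholders): declare all feeds/fetches up front,
// then supply them incrementally against the returned handle.
//
//   const char* handle = nullptr;
//   TF_SessionPRunSetup(session, feeds, nfeeds, fetches, nfetches,
//                       /*target_opers=*/nullptr, 0, &handle, status);
//   // Each TF_SessionPRun call may feed/fetch any subset of the tensors
//   // declared in the setup call:
//   TF_SessionPRun(session, handle, &feeds[0], &feed_val, 1,
//                  &fetches[0], &fetch_val, 1, nullptr, 0, status);
//   TF_DeletePRunHandle(handle);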

unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
                                     TF_Tensor** result, TF_Status* status) {
  *result = nullptr;
  mutex_lock l(graph->mu);
  OutputTensor tensor(&output.oper->node, output.index);
  bool evaluated;
  Tensor result_tensor;
  status->status = EvaluateConstantTensor(
      tensor, graph->refiner, *graph->graph.op_registry(),
      graph->graph.versions().producer(), &evaluated, &result_tensor);
  if (evaluated) {
    DCHECK(TF_GetCode(status) == TF_OK);
    *result = TF_TensorFromTensor(result_tensor, status);
    if (TF_GetCode(status) != TF_OK) evaluated = false;
  }
  return evaluated;
}
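
// Illustrative use (sketch; `some_output` is a placeholder TF_Output): a
// false return simply means the value could not be statically evaluated,
// not that an error occurred.
//
//   TF_Tensor* value = nullptr;
//   if (TF_TryEvaluateConstant(graph, some_output, &value, status)) {
//     // `value` now owns the folded tensor; release with TF_DeleteTensor.
//   }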

TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
  tensorflow::OpList op_list;
  if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
    status->status = InvalidArgument("Unparseable OpList");
    return nullptr;
  }
  status->status = Status::OK();
  return new TF_ApiDefMap(op_list);
}

void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }

void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
                     size_t text_len, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "ApiDefMap is not supported on mobile.");
#else
  mutex_lock l(api_def_map->lock);
  if (api_def_map->update_docs_called) {
    status->status = FailedPrecondition(
        "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
        "called.");
    return;
  }
  string api_def_text(text, text_len);
  status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
                           size_t name_len, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "ApiDefMap is not supported on mobile.");
  return nullptr;
#else
  mutex_lock l(api_def_map->lock);
  if (!api_def_map->update_docs_called) {
    api_def_map->api_def_map.UpdateDocs();
    api_def_map->update_docs_called = true;
  }
  string name_str(name, name_len);
  const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
  if (api_def == nullptr) {
    return nullptr;
  }

  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(*api_def, ret);
  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
  return ret;
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
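
// Typical workflow (sketch; `api_def_text` is a caller-supplied ApiDef in
// text-proto form): seed the map from the registered ops, layer on
// overrides, then query. All puts must precede the first get (see the
// update_docs_called check above).
//
//   TF_Buffer* ops = TF_GetAllOpList();
//   TF_ApiDefMap* api_map = TF_NewApiDefMap(ops, status);
//   TF_ApiDefMapPut(api_map, api_def_text, strlen(api_def_text), status);
//   TF_Buffer* api_def = TF_ApiDefMapGet(api_map, "MatMul", 6, status);
//   // ... parse `api_def` as an ApiDef proto, then free both buffers and
//   // call TF_DeleteApiDefMap(api_map).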

TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
  tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(kernel_list, ret);
  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
  return ret;
}

TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
  tensorflow::KernelList kernel_list =
      tensorflow::GetRegisteredKernelsForOp(name);
  TF_Buffer* ret = TF_NewBuffer();
  status->status = MessageToBuffer(kernel_list, ret);
  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteBuffer(ret);
    return nullptr;
  }
  return ret;
}
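
// Both functions return a serialized KernelList proto (sketch; the op name
// is only an example):
//
//   TF_Buffer* kernels = TF_GetRegisteredKernelsForOp("MatMul", status);
//   if (TF_GetCode(status) == TF_OK) {
//     // Parse kernels->data (kernels->length bytes) as a KernelList.
//     TF_DeleteBuffer(kernels);
//   }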

// TF_Server functions ----------------------------------------------

#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
TF_Server::TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)
    : target(server->target()), server(std::move(server)) {}
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

TF_Server* TF_NewServer(const void* proto, size_t proto_len,
                        TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
  return nullptr;
#else
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, static_cast<int>(proto_len))) {
    status->status = InvalidArgument(
        "Could not parse provided bytes into a ServerDef protocol buffer");
    return nullptr;
  }

  std::unique_ptr<tensorflow::ServerInterface> out_server;
  status->status = tensorflow::NewServer(server_def, &out_server);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  return new TF_Server(std::move(out_server));
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

void TF_ServerStart(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  status->status = server->server->Start();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

void TF_ServerStop(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  status->status = server->server->Stop();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

void TF_ServerJoin(TF_Server* server, TF_Status* status) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Server functionality is not supported on mobile");
#else
  status->status = server->server->Join();
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}

const char* TF_ServerTarget(TF_Server* server) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  return nullptr;
#else
  return server->target.c_str();
#endif
}

void TF_DeleteServer(TF_Server* server) {
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
  delete server;
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
}
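
// In-process server lifecycle (sketch; `server_def_bytes`/`server_def_len`
// hold a caller-supplied serialized ServerDef proto):
//
//   TF_Server* server = TF_NewServer(server_def_bytes, server_def_len,
//                                    status);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_ServerStart(server, status);
//     // Sessions may now connect to TF_ServerTarget(server).
//     TF_ServerJoin(server, status);  // blocks until the server shuts down
//   }
//   TF_DeleteServer(server);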

void TF_RegisterLogListener(void (*listener)(const char*)) {
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
  tensorflow::logging::RegisterListener(listener);
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
}
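
// Example listener (sketch): registered listeners receive messages produced
// by print kernels; when any listener is registered, those messages are
// routed to the listeners instead of the logs.
//
//   static void MyLogListener(const char* msg) {
//     fprintf(stderr, "[tf] %s\n", msg);
//   }
//   ...
//   TF_RegisterLogListener(MyLogListener);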

}  // end extern "C"