1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/c_api.h"
17
#include <algorithm>
#include <cstring>
#include <limits>
#include <memory>
#include <vector>
22
23 #ifndef __ANDROID__
24 #include "tensorflow/cc/framework/gradients.h"
25 #include "tensorflow/cc/framework/ops.h"
26 #include "tensorflow/cc/framework/scope_internal.h"
27 #include "tensorflow/cc/ops/while_loop.h"
28 #include "tensorflow/cc/saved_model/loader.h"
29 #include "tensorflow/core/framework/op_gen_lib.h"
30 #endif
31 #include "tensorflow/c/c_api_internal.h"
32 #include "tensorflow/core/common_runtime/device_mgr.h"
33 #include "tensorflow/core/common_runtime/shape_refiner.h"
34 #include "tensorflow/core/framework/allocation_description.pb.h"
35 #include "tensorflow/core/framework/log_memory.h"
36 #include "tensorflow/core/framework/node_def_util.h"
37 #include "tensorflow/core/framework/op_kernel.h"
38 #include "tensorflow/core/framework/partial_tensor_shape.h"
39 #include "tensorflow/core/framework/tensor.h"
40 #include "tensorflow/core/framework/tensor_shape.h"
41 #include "tensorflow/core/framework/tensor_shape.pb.h"
42 #include "tensorflow/core/framework/types.h"
43 #include "tensorflow/core/framework/versions.pb.h"
44 #include "tensorflow/core/graph/graph.h"
45 #include "tensorflow/core/graph/graph_constructor.h"
46 #include "tensorflow/core/graph/node_builder.h"
47 #include "tensorflow/core/lib/core/coding.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/core/status.h"
50 #include "tensorflow/core/lib/core/stringpiece.h"
51 #include "tensorflow/core/lib/gtl/array_slice.h"
52 #include "tensorflow/core/lib/strings/strcat.h"
53 #include "tensorflow/core/platform/mem.h"
54 #include "tensorflow/core/platform/mutex.h"
55 #include "tensorflow/core/platform/protobuf.h"
56 #include "tensorflow/core/platform/thread_annotations.h"
57 #include "tensorflow/core/platform/types.h"
58 #include "tensorflow/core/public/session.h"
59 #include "tensorflow/core/public/version.h"
60
// The implementation below is at the top level instead of inside the
// tensorflow namespace because we are defining 'extern "C"' functions.
63 using tensorflow::AllocationDescription;
64 using tensorflow::DataType;
65 using tensorflow::Graph;
66 using tensorflow::GraphDef;
67 using tensorflow::mutex_lock;
68 using tensorflow::NameRangeMap;
69 using tensorflow::NameRangesForNode;
70 using tensorflow::NewSession;
71 using tensorflow::Node;
72 using tensorflow::NodeBuilder;
73 using tensorflow::NodeDef;
74 using tensorflow::OpDef;
75 using tensorflow::OpRegistry;
76 using tensorflow::PartialTensorShape;
77 using tensorflow::RunMetadata;
78 using tensorflow::RunOptions;
79 using tensorflow::Session;
80 using tensorflow::Status;
81 using tensorflow::string;
82 using tensorflow::Tensor;
83 using tensorflow::TensorBuffer;
84 using tensorflow::TensorId;
85 using tensorflow::TensorShape;
86 using tensorflow::TensorShapeProto;
87 using tensorflow::VersionDef;
88 using tensorflow::error::Code;
89 using tensorflow::errors::FailedPrecondition;
90 using tensorflow::errors::InvalidArgument;
91 using tensorflow::gtl::ArraySlice;
92 using tensorflow::strings::StrCat;
93
94 extern "C" {
95
96 // --------------------------------------------------------------------------
TF_Version()97 const char* TF_Version() { return TF_VERSION_STRING; }
98
99 // --------------------------------------------------------------------------
TF_DataTypeSize(TF_DataType dt)100 size_t TF_DataTypeSize(TF_DataType dt) {
101 return static_cast<size_t>(
102 tensorflow::DataTypeSize(static_cast<DataType>(dt)));
103 }
104
105 // --------------------------------------------------------------------------
106
TF_NewStatus()107 TF_Status* TF_NewStatus() { return new TF_Status; }
108
TF_DeleteStatus(TF_Status * s)109 void TF_DeleteStatus(TF_Status* s) { delete s; }
110
TF_SetStatus(TF_Status * s,TF_Code code,const char * msg)111 void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) {
112 if (code == TF_OK) {
113 s->status = Status::OK();
114 return;
115 }
116 s->status = Status(static_cast<Code>(code), tensorflow::StringPiece(msg));
117 }
118
TF_GetCode(const TF_Status * s)119 TF_Code TF_GetCode(const TF_Status* s) {
120 return static_cast<TF_Code>(s->status.code());
121 }
122
TF_Message(const TF_Status * s)123 const char* TF_Message(const TF_Status* s) {
124 return s->status.error_message().c_str();
125 }
126
127 // --------------------------------------------------------------------------
128
129 namespace {
130 class TF_ManagedBuffer : public TensorBuffer {
131 public:
132 void* data_;
133 size_t len_;
134 void (*deallocator_)(void* data, size_t len, void* arg);
135 void* deallocator_arg_;
136
~TF_ManagedBuffer()137 ~TF_ManagedBuffer() override {
138 (*deallocator_)(data_, len_, deallocator_arg_);
139 }
140
data() const141 void* data() const override { return data_; }
size() const142 size_t size() const override { return len_; }
root_buffer()143 TensorBuffer* root_buffer() override { return this; }
FillAllocationDescription(AllocationDescription * proto) const144 void FillAllocationDescription(AllocationDescription* proto) const override {
145 tensorflow::int64 rb = size();
146 proto->set_requested_bytes(rb);
147 proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
148 }
149
150 // Prevents input forwarding from mutating this buffer.
OwnsMemory() const151 bool OwnsMemory() const override { return false; }
152 };
153
allocate_tensor(const char * operation,size_t len)154 void* allocate_tensor(const char* operation, size_t len) {
155 void* data =
156 tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len);
157 if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
158 tensorflow::LogMemory::RecordRawAllocation(
159 operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID,
160 len, data, tensorflow::cpu_allocator());
161 }
162 return data;
163 }
164
deallocate_buffer(void * data,size_t len,void * arg)165 void deallocate_buffer(void* data, size_t len, void* arg) {
166 if (tensorflow::LogMemory::IsEnabled() && data != nullptr) {
167 tensorflow::LogMemory::RecordRawDeallocation(
168 "TensorFlow C Api",
169 tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data,
170 tensorflow::cpu_allocator(), false);
171 }
172 tensorflow::cpu_allocator()->DeallocateRaw(data);
173 }
174
175 } // namespace
176
~TF_Tensor()177 TF_Tensor::~TF_Tensor() { buffer->Unref(); }
178
TF_AllocateTensor(TF_DataType dtype,const int64_t * dims,int num_dims,size_t len)179 TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
180 int num_dims, size_t len) {
181 void* data = allocate_tensor("TF_AllocateTensor", len);
182 return TF_NewTensor(dtype, dims, num_dims, data, len, deallocate_buffer,
183 nullptr);
184 }
185
TF_NewTensor(TF_DataType dtype,const int64_t * dims,int num_dims,void * data,size_t len,void (* deallocator)(void * data,size_t len,void * arg),void * deallocator_arg)186 TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
187 void* data, size_t len,
188 void (*deallocator)(void* data, size_t len, void* arg),
189 void* deallocator_arg) {
190 std::vector<tensorflow::int64> dimvec(num_dims);
191 for (int i = 0; i < num_dims; ++i) {
192 dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
193 }
194
195 TF_ManagedBuffer* buf = new TF_ManagedBuffer;
196 buf->len_ = len;
197 if (dtype != TF_STRING && dtype != TF_RESOURCE &&
198 tensorflow::DataTypeCanUseMemcpy(static_cast<DataType>(dtype)) &&
199 reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) {
200 // TF_STRING and TF_RESOURCE tensors have a different representation in
201 // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste
202 // (any alignment requirements will be taken care of by TF_TensorToTensor
203 // and TF_TensorFromTensor).
204 //
205 // Other types have the same representation, so copy only if it is safe to
206 // do so.
207 buf->data_ = allocate_tensor("TF_NewTensor", len);
208 std::memcpy(buf->data_, data, len);
209 buf->deallocator_ = deallocate_buffer;
210 buf->deallocator_arg_ = nullptr;
211 // Free the original buffer.
212 deallocator(data, len, deallocator_arg);
213 } else {
214 buf->data_ = data;
215 buf->deallocator_ = deallocator;
216 buf->deallocator_arg_ = deallocator_arg;
217 }
218 TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf};
219 size_t elem_size = TF_DataTypeSize(dtype);
220 if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) {
221 delete ret;
222 return nullptr;
223 }
224 return ret;
225 }
226
TF_TensorMaybeMove(TF_Tensor * tensor)227 TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
228 // It is safe to move the Tensor if and only if we own the unique reference to
229 // it. In that case, we might as well not delete and reallocate, but a future
230 // implementation might need to do so.
231 TensorBuffer* buf = tensor->buffer;
232 if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
233 buf->OwnsMemory()) {
234 return tensor;
235 }
236 return nullptr;
237 }
238
TF_DeleteTensor(TF_Tensor * t)239 void TF_DeleteTensor(TF_Tensor* t) { delete t; }
240
TF_TensorType(const TF_Tensor * t)241 TF_DataType TF_TensorType(const TF_Tensor* t) { return t->dtype; }
TF_NumDims(const TF_Tensor * t)242 int TF_NumDims(const TF_Tensor* t) { return t->shape.dims(); }
TF_Dim(const TF_Tensor * t,int dim_index)243 int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
244 return static_cast<int64_t>(t->shape.dim_size(dim_index));
245 }
TF_TensorByteSize(const TF_Tensor * t)246 size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); }
TF_TensorData(const TF_Tensor * t)247 void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); }
248
249 // --------------------------------------------------------------------------
TF_StringEncode(const char * src,size_t src_len,char * dst,size_t dst_len,TF_Status * status)250 size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
251 size_t dst_len, TF_Status* status) {
252 const size_t sz = TF_StringEncodedSize(src_len);
253 if (sz < src_len) {
254 status->status = InvalidArgument("src string is too large to encode");
255 return 0;
256 }
257 if (dst_len < sz) {
258 status->status =
259 InvalidArgument("dst_len (", dst_len, ") too small to encode a ",
260 src_len, "-byte string");
261 return 0;
262 }
263 dst = tensorflow::core::EncodeVarint64(dst, src_len);
264 memcpy(dst, src, src_len);
265 return sz;
266 }
267
TF_StringDecode_Impl(const char * src,size_t src_len,const char ** dst,size_t * dst_len)268 static Status TF_StringDecode_Impl(const char* src, size_t src_len,
269 const char** dst, size_t* dst_len) {
270 tensorflow::uint64 len64 = 0;
271 const char* p = tensorflow::core::GetVarint64Ptr(src, src + src_len, &len64);
272 if (p == nullptr) {
273 return InvalidArgument("invalid string encoding or truncated src buffer");
274 }
275 if (len64 > std::numeric_limits<size_t>::max()) {
276 return InvalidArgument("encoded string is ", len64,
277 "-bytes, which is too large for this architecture");
278 }
279 *dst = p;
280 *dst_len = static_cast<size_t>(len64);
281 return Status::OK();
282 }
283
TF_StringDecode(const char * src,size_t src_len,const char ** dst,size_t * dst_len,TF_Status * status)284 size_t TF_StringDecode(const char* src, size_t src_len, const char** dst,
285 size_t* dst_len, TF_Status* status) {
286 status->status = TF_StringDecode_Impl(src, src_len, dst, dst_len);
287 if (!status->status.ok()) return 0;
288 return static_cast<size_t>(*dst - src) + *dst_len;
289 }
290
TF_StringEncodedSize(size_t len)291 size_t TF_StringEncodedSize(size_t len) {
292 return static_cast<size_t>(tensorflow::core::VarintLength(len)) + len;
293 }
294
295 // --------------------------------------------------------------------------
TF_NewSessionOptions()296 TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
TF_DeleteSessionOptions(TF_SessionOptions * opt)297 void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
298
TF_SetTarget(TF_SessionOptions * options,const char * target)299 void TF_SetTarget(TF_SessionOptions* options, const char* target) {
300 options->options.target = target;
301 }
302
TF_SetConfig(TF_SessionOptions * options,const void * proto,size_t proto_len,TF_Status * status)303 void TF_SetConfig(TF_SessionOptions* options, const void* proto,
304 size_t proto_len, TF_Status* status) {
305 if (!options->options.config.ParseFromArray(proto, proto_len)) {
306 status->status = InvalidArgument("Unparseable ConfigProto");
307 }
308 }
309 // --------------------------------------------------------------------------
TF_NewBuffer()310 TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; }
311
TF_NewBufferFromString(const void * proto,size_t proto_len)312 TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
313 void* copy = tensorflow::port::Malloc(proto_len);
314 memcpy(copy, proto, proto_len);
315
316 TF_Buffer* buf = new TF_Buffer;
317 buf->data = copy;
318 buf->length = proto_len;
319 buf->data_deallocator = [](void* data, size_t length) {
320 tensorflow::port::Free(data);
321 };
322 return buf;
323 }
324
TF_DeleteBuffer(TF_Buffer * buffer)325 void TF_DeleteBuffer(TF_Buffer* buffer) {
326 if (buffer->data_deallocator != nullptr) {
327 (*buffer->data_deallocator)(const_cast<void*>(buffer->data),
328 buffer->length);
329 }
330 delete buffer;
331 }
332
TF_GetBuffer(TF_Buffer * buffer)333 TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }
334
335 // --------------------------------------------------------------------------
336
TF_NewDeprecatedSession(const TF_SessionOptions * opt,TF_Status * status)337 TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
338 TF_Status* status) {
339 Session* session;
340 status->status = NewSession(opt->options, &session);
341 if (status->status.ok()) {
342 return new TF_DeprecatedSession({session});
343 } else {
344 DCHECK_EQ(nullptr, session);
345 return nullptr;
346 }
347 }
348
TF_CloseDeprecatedSession(TF_DeprecatedSession * s,TF_Status * status)349 void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
350 status->status = s->session->Close();
351 }
352
TF_DeleteDeprecatedSession(TF_DeprecatedSession * s,TF_Status * status)353 void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
354 status->status = Status::OK();
355 delete s->session;
356 delete s;
357 }
358
TF_ExtendGraph(TF_DeprecatedSession * s,const void * proto,size_t proto_len,TF_Status * status)359 void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
360 size_t proto_len, TF_Status* status) {
361 GraphDef g;
362 if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
363 status->status = InvalidArgument("Invalid GraphDef");
364 return;
365 }
366 status->status = s->session->Extend(g);
367 }
368
DeleteArray(void * data,size_t size,void * arg)369 static void DeleteArray(void* data, size_t size, void* arg) {
370 DCHECK_EQ(data, arg);
371 delete[] reinterpret_cast<char*>(arg);
372 }
373
374 } // end extern "C"
375
376 namespace tensorflow {
377 namespace {
378
379 // Reset helper for converting character arrays to string vectors.
TF_Reset_Helper(const TF_SessionOptions * opt,const char ** containers,int ncontainers,TF_Status * status)380 void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers,
381 int ncontainers, TF_Status* status) {
382 std::vector<string> container_names(ncontainers);
383 for (int i = 0; i < ncontainers; ++i) {
384 container_names[i] = containers[i];
385 }
386
387 status->status = Reset(opt->options, container_names);
388 }
389
390 // This traverses the specified nodes in topological order to verify there are
391 // no cycles. Starting with inputless nodes, it visits nodes whose inputs have
392 // all been visited, and counts the total number of visited nodes. If there is a
393 // cycle, nodes in the cycle will never be visited, and the visited count will
394 // be less than the total node count.
ValidateNoCycles(const Graph & g)395 Status ValidateNoCycles(const Graph& g) {
396 // TODO(nolivia): check this on a subset of the graph instead of all of it.
397 // A node is ready when all of its inputs have been visited.
398 std::vector<const Node*> ready;
399 std::vector<int> pending_count(g.num_node_ids(), 0);
400
401 for (int i = 0; i < g.num_node_ids(); ++i) {
402 const Node* n = g.FindNodeId(i);
403 if (n == nullptr) continue;
404 pending_count[i] = n->in_edges().size();
405 if (n->IsMerge()) {
406 // While-loop cycles are legal cycles so we manually adjust the
407 // pending_count to make sure that the loop is visited.
408 for (const Edge* e : n->in_edges()) {
409 if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
410 pending_count[i]--;
411 }
412 }
413 }
414 if (pending_count[i] == 0) {
415 ready.push_back(n);
416 }
417 }
418
419 int processed = 0;
420 while (!ready.empty()) {
421 const Node* node = ready.back();
422 ready.pop_back();
423 ++processed;
424
425 for (const Edge* out : node->out_edges()) {
426 const int output_id = out->dst()->id();
427 pending_count[output_id]--;
428 if (pending_count[output_id] == 0) {
429 ready.push_back(out->dst());
430 }
431 }
432 }
433
434 if (processed < g.num_nodes()) {
435 std::vector<string> nodes_in_cycle;
436 for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3;
437 ++i) {
438 if (pending_count[i] != 0) {
439 nodes_in_cycle.push_back(g.FindNodeId(i)->name());
440 }
441 }
442 return errors::InvalidArgument(
443 "Graph is invalid, contains a cycle with ", g.num_nodes() - processed,
444 " nodes, including: ", str_util::Join(nodes_in_cycle, ", "));
445 }
446 return Status::OK();
447 }
448 } // namespace
449 } // namespace tensorflow
450
451 extern "C" {
452
TF_Reset(const TF_SessionOptions * opt,const char ** containers,int ncontainers,TF_Status * status)453 void TF_Reset(const TF_SessionOptions* opt, const char** containers,
454 int ncontainers, TF_Status* status) {
455 tensorflow::TF_Reset_Helper(opt, containers, ncontainers, status);
456 }
457
458 } // end extern "C"
459
460 namespace tensorflow {
461
TF_TensorToTensor(const TF_Tensor * src,Tensor * dst)462 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
463 if (src->dtype == TF_RESOURCE) {
464 if (src->shape.dims() != 0) {
465 return InvalidArgument(
466 "Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
467 "shape ",
468 src->shape.DebugString());
469 }
470 *dst = Tensor(DT_RESOURCE, src->shape);
471 if (!dst->scalar<ResourceHandle>()().ParseFromString(
472 string(static_cast<const char*>(TF_TensorData(src)),
473 TF_TensorByteSize(src)))) {
474 return InvalidArgument(
475 "Malformed TF_RESOUCE tensor: unable to parse resource handle");
476 }
477 return Status::OK();
478 }
479 if (src->dtype != TF_STRING) {
480 *dst = TensorCApi::MakeTensor(src->dtype, src->shape, src->buffer);
481 return Status::OK();
482 }
483 // TF_STRING tensors require copying since Tensor class expects a sequence of
484 // string objects.
485 const tensorflow::int64 num_elements = src->shape.num_elements();
486 const char* input = reinterpret_cast<const char*>(TF_TensorData(src));
487 const size_t src_size = TF_TensorByteSize(src);
488 if (static_cast<tensorflow::int64>(src_size / sizeof(tensorflow::uint64)) <
489 num_elements) {
490 return InvalidArgument(
491 "Malformed TF_STRING tensor; too short to hold number of elements");
492 }
493 const char* data_start = input + sizeof(tensorflow::uint64) * num_elements;
494 const char* limit = input + src_size;
495
496 *dst = Tensor(static_cast<DataType>(src->dtype), src->shape);
497 auto dstarray = dst->flat<string>();
498 for (tensorflow::int64 i = 0; i < num_elements; ++i) {
499 tensorflow::uint64 offset =
500 reinterpret_cast<const tensorflow::uint64*>(input)[i];
501 if (static_cast<ptrdiff_t>(offset) >= (limit - data_start)) {
502 return InvalidArgument("Malformed TF_STRING tensor; element ", i,
503 " out of range");
504 }
505 size_t len;
506 const char* p;
507 const char* srcp = data_start + offset;
508 Status status = TF_StringDecode_Impl(srcp, limit - srcp, &p, &len);
509 if (!status.ok()) return status;
510 dstarray(i).assign(p, len);
511 }
512 return Status::OK();
513 }
514
515 // Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
516 // result in a zero-sized tensor.
EmptyTensor(TF_DataType dtype,const TensorShape & shape)517 static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) {
518 static char empty;
519 tensorflow::int64 nelems = 1;
520 std::vector<tensorflow::int64> dims;
521 for (int i = 0; i < shape.dims(); ++i) {
522 dims.push_back(shape.dim_size(i));
523 nelems *= shape.dim_size(i);
524 }
525 CHECK_EQ(nelems, 0);
526 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
527 "64-bit int types should match in size");
528 return TF_NewTensor(dtype, reinterpret_cast<const int64_t*>(dims.data()),
529 shape.dims(), reinterpret_cast<void*>(&empty), 0,
530 [](void*, size_t, void*) {}, nullptr);
531 }
532
533 // Non-static for testing.
TF_TensorFromTensor(const tensorflow::Tensor & src,TF_Status * status)534 TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
535 TF_Status* status) {
536 if (!src.IsInitialized()) {
537 status->status = FailedPrecondition(
538 "attempt to use a tensor with an uninitialized value");
539 return nullptr;
540 }
541 if (src.NumElements() == 0) {
542 return EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
543 }
544 if (src.dtype() == DT_RESOURCE) {
545 if (src.shape().dims() != 0) {
546 status->status = InvalidArgument(
547 "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
548 src.shape().DebugString(),
549 "). Please file a bug at "
550 "https://github.com/tensorflow/tensorflow/issues/new, "
551 "ideally with a "
552 "short code snippet that reproduces this error.");
553 return nullptr;
554 }
555 const string str = src.scalar<ResourceHandle>()().SerializeAsString();
556 TF_Tensor* t = TF_AllocateTensor(TF_RESOURCE, {}, 0, str.size());
557 std::memcpy(TF_TensorData(t), str.c_str(), str.size());
558 return t;
559 }
560 if (src.dtype() != DT_STRING) {
561 TensorBuffer* buf = TensorCApi::Buffer(src);
562 buf->Ref();
563 return new TF_Tensor{static_cast<TF_DataType>(src.dtype()), src.shape(),
564 buf};
565 }
566 // DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
567 // encoded sequence of strings.
568
569 // Compute bytes needed for encoding.
570 size_t size = 0;
571 const auto& srcarray = src.flat<string>();
572 for (int i = 0; i < srcarray.size(); ++i) {
573 const string& s = srcarray(i);
574 // uint64 starting_offset, TF_StringEncode-d string.
575 size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size());
576 }
577
578 // Encode all strings.
579 char* base = new char[size];
580 char* data_start = base + sizeof(tensorflow::uint64) * srcarray.size();
581 char* dst = data_start; // Where next string is encoded.
582 size_t dst_len = size - static_cast<size_t>(data_start - base);
583 tensorflow::uint64* offsets = reinterpret_cast<tensorflow::uint64*>(base);
584 for (int i = 0; i < srcarray.size(); ++i) {
585 *offsets = (dst - data_start);
586 offsets++;
587 const string& s = srcarray(i);
588 size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status);
589 if (!status->status.ok()) {
590 status->status = InvalidArgument(
591 "invalid string tensor encoding (string #", i, " of ",
592 srcarray.size(), "): ", status->status.error_message());
593 delete[] base;
594 return nullptr;
595 }
596 dst += consumed;
597 dst_len -= consumed;
598 }
599 if (dst != base + size) {
600 status->status = InvalidArgument(
601 "invalid string tensor encoding (decoded ", (dst - base),
602 " bytes, but the tensor is encoded in ", size, " bytes");
603 delete[] base;
604 return nullptr;
605 }
606
607 auto dims = src.shape().dim_sizes();
608 std::vector<tensorflow::int64> dimvec(dims.size());
609 for (size_t i = 0; i < dims.size(); ++i) {
610 dimvec[i] = dims[i];
611 }
612 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
613 "64-bit int types should match in size");
614 return TF_NewTensor(TF_STRING,
615 reinterpret_cast<const int64_t*>(dimvec.data()),
616 dimvec.size(), base, size, DeleteArray, base);
617 }
618
MessageToBuffer(const tensorflow::protobuf::Message & in,TF_Buffer * out)619 Status MessageToBuffer(const tensorflow::protobuf::Message& in,
620 TF_Buffer* out) {
621 if (out->data != nullptr) {
622 return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
623 }
624 const size_t proto_size = in.ByteSizeLong();
625 void* buf = tensorflow::port::Malloc(proto_size);
626 if (buf == nullptr) {
627 return tensorflow::errors::ResourceExhausted(
628 "Failed to allocate memory to serialize message of type '",
629 in.GetTypeName(), "' and size ", proto_size);
630 }
631 in.SerializeToArray(buf, proto_size);
632 out->data = buf;
633 out->length = proto_size;
634 out->data_deallocator = [](void* data, size_t length) {
635 tensorflow::port::Free(data);
636 };
637 return Status::OK();
638 }
639
RecordMutation(TF_Graph * graph,const TF_Operation & op,const char * mutation_type)640 void RecordMutation(TF_Graph* graph, const TF_Operation& op,
641 const char* mutation_type)
642 EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
643 // If any session has already run this node_id, mark this session as
644 // unrunnable.
645 for (auto it : graph->sessions) {
646 if (it.first->last_num_graph_nodes > op.node.id()) {
647 it.second = FailedPrecondition(
648 "Operation '", op.node.DebugString(), "' was changed by ",
649 mutation_type,
650 " after it was run by a session. Nodes can be mutated "
651 "only before they are executed by a session. Either don't modify "
652 "nodes after running them or create a new session.");
653 }
654 }
655 }
656
657 namespace {
658
659 // Helper method that creates a shape handle for a shape described by dims.
ShapeHandleFromDims(tensorflow::shape_inference::InferenceContext * ic,int num_dims,const int64_t * dims)660 tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
661 tensorflow::shape_inference::InferenceContext* ic, int num_dims,
662 const int64_t* dims) {
663 if (num_dims != -1) {
664 std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
665 dim_vec.reserve(num_dims);
666 for (int i = 0; i < num_dims; ++i) {
667 dim_vec.push_back(ic->MakeDim(dims[i]));
668 }
669 return ic->MakeShape(dim_vec);
670 } else {
671 return ic->UnknownShape();
672 }
673 }
674
675 } // namespace
676
TF_GraphSetOutputHandleShapesAndTypes(TF_Graph * graph,TF_Output output,int num_shapes_and_types,const int64_t ** shapes,const int * ranks,const TF_DataType * types,TF_Status * status)677 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
678 int num_shapes_and_types,
679 const int64_t** shapes,
680 const int* ranks,
681 const TF_DataType* types,
682 TF_Status* status) {
683 Node* node = &output.oper->node;
684
685 mutex_lock l(graph->mu);
686 tensorflow::shape_inference::InferenceContext* ic =
687 graph->refiner.GetContext(node);
688 if (ic == nullptr) {
689 status->status =
690 InvalidArgument("Node ", node->name(), " was not found in the graph");
691 return;
692 }
693
694 auto shape_and_type_vec =
695 std::vector<tensorflow::shape_inference::ShapeAndType>(
696 num_shapes_and_types);
697 for (int i = 0; i < num_shapes_and_types; ++i) {
698 tensorflow::shape_inference::ShapeHandle shape_handle =
699 ShapeHandleFromDims(ic, ranks[i], shapes[i]);
700 shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
701 shape_handle, static_cast<DataType>(types[i]));
702 }
703
704 ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
705 }
706
707 // Helpers for loading a TensorFlow plugin (a .so file).
708 Status LoadLibrary(const char* library_filename, void** result,
709 const void** buf, size_t* len);
710
711 } // namespace tensorflow
712
TF_Run_Setup(int noutputs,TF_Tensor ** c_outputs,TF_Status * status)713 static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
714 TF_Status* status) {
715 status->status = Status::OK();
716 for (int i = 0; i < noutputs; ++i) {
717 c_outputs[i] = nullptr;
718 }
719 }
720
TF_Run_Inputs(TF_Tensor * const * c_inputs,std::vector<std::pair<string,Tensor>> * input_pairs,TF_Status * status)721 static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
722 std::vector<std::pair<string, Tensor>>* input_pairs,
723 TF_Status* status) {
724 const int ninputs = input_pairs->size();
725 for (int i = 0; i < ninputs; ++i) {
726 status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
727 if (!status->status.ok()) return false;
728 }
729 return true;
730 }
731
TF_Run_Helper(Session * session,const char * handle,const TF_Buffer * run_options,const std::vector<std::pair<string,Tensor>> & input_pairs,const std::vector<string> & output_tensor_names,TF_Tensor ** c_outputs,const std::vector<string> & target_oper_names,TF_Buffer * run_metadata,TF_Status * status)732 static void TF_Run_Helper(
733 Session* session, const char* handle, const TF_Buffer* run_options,
734 // Input tensors
735 const std::vector<std::pair<string, Tensor>>& input_pairs,
736 // Output tensors
737 const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
738 // Target nodes
739 const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
740 TF_Status* status) {
741 const int noutputs = output_tensor_names.size();
742 std::vector<Tensor> outputs(noutputs);
743 Status result;
744
745 if (handle == nullptr) {
746 RunOptions run_options_proto;
747 if (run_options != nullptr && !run_options_proto.ParseFromArray(
748 run_options->data, run_options->length)) {
749 status->status = InvalidArgument("Unparseable RunOptions proto");
750 return;
751 }
752 if (run_metadata != nullptr && run_metadata->data != nullptr) {
753 status->status =
754 InvalidArgument("Passing non-empty run_metadata is invalid.");
755 return;
756 }
757
758 RunMetadata run_metadata_proto;
759 result = session->Run(run_options_proto, input_pairs, output_tensor_names,
760 target_oper_names, &outputs, &run_metadata_proto);
761
762 // Serialize back to upstream client, who now owns the new buffer
763 if (run_metadata != nullptr) {
764 status->status = MessageToBuffer(run_metadata_proto, run_metadata);
765 if (!status->status.ok()) return;
766 }
767 } else {
768 // NOTE(zongheng): PRun does not support RunOptions yet.
769 result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
770 }
771 if (!result.ok()) {
772 status->status = result;
773 return;
774 }
775
776 // Store results in c_outputs[]
777 for (int i = 0; i < noutputs; ++i) {
778 const Tensor& src = outputs[i];
779 if (!src.IsInitialized() || src.NumElements() == 0) {
780 c_outputs[i] =
781 EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
782 continue;
783 }
784 c_outputs[i] = TF_TensorFromTensor(src, status);
785 if (!status->status.ok()) return;
786 }
787 }
788
789 extern "C" {
790
TF_Run(TF_DeprecatedSession * s,const TF_Buffer * run_options,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Buffer * run_metadata,TF_Status * status)791 void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
792 // Input tensors
793 const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
794 // Output tensors
795 const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
796 // Target nodes
797 const char** c_target_oper_names, int ntargets,
798 TF_Buffer* run_metadata, TF_Status* status) {
799 TF_Run_Setup(noutputs, c_outputs, status);
800 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
801 if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
802 for (int i = 0; i < ninputs; ++i) {
803 input_pairs[i].first = c_input_names[i];
804 }
805 std::vector<string> output_names(noutputs);
806 for (int i = 0; i < noutputs; ++i) {
807 output_names[i] = c_output_names[i];
808 }
809 std::vector<string> target_oper_names(ntargets);
810 for (int i = 0; i < ntargets; ++i) {
811 target_oper_names[i] = c_target_oper_names[i];
812 }
813 TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
814 c_outputs, target_oper_names, run_metadata, status);
815 }
816
TF_PRunSetup(TF_DeprecatedSession * s,const char ** c_input_names,int ninputs,const char ** c_output_names,int noutputs,const char ** c_target_oper_names,int ntargets,const char ** handle,TF_Status * status)817 void TF_PRunSetup(TF_DeprecatedSession* s,
818 // Input names
819 const char** c_input_names, int ninputs,
820 // Output names
821 const char** c_output_names, int noutputs,
822 // Target nodes
823 const char** c_target_oper_names, int ntargets,
824 const char** handle, TF_Status* status) {
825 *handle = nullptr;
826
827 std::vector<string> input_names(ninputs);
828 std::vector<string> output_names(noutputs);
829 std::vector<string> target_oper_names(ntargets);
830 for (int i = 0; i < ninputs; ++i) {
831 input_names[i] = c_input_names[i];
832 }
833 for (int i = 0; i < noutputs; ++i) {
834 output_names[i] = c_output_names[i];
835 }
836 for (int i = 0; i < ntargets; ++i) {
837 target_oper_names[i] = c_target_oper_names[i];
838 }
839 string new_handle;
840 status->status = s->session->PRunSetup(input_names, output_names,
841 target_oper_names, &new_handle);
842 if (status->status.ok()) {
843 char* buf = new char[new_handle.size() + 1];
844 memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
845 *handle = buf;
846 }
847 }
848
TF_PRun(TF_DeprecatedSession * s,const char * handle,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Status * status)849 void TF_PRun(TF_DeprecatedSession* s, const char* handle,
850 // Input tensors
851 const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
852 // Output tensors
853 const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
854 // Target nodes
855 const char** c_target_oper_names, int ntargets,
856 TF_Status* status) {
857 TF_Run_Setup(noutputs, c_outputs, status);
858 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
859 if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
860 for (int i = 0; i < ninputs; ++i) {
861 input_pairs[i].first = c_input_names[i];
862 }
863
864 std::vector<string> output_names(noutputs);
865 for (int i = 0; i < noutputs; ++i) {
866 output_names[i] = c_output_names[i];
867 }
868 std::vector<string> target_oper_names(ntargets);
869 for (int i = 0; i < ntargets; ++i) {
870 target_oper_names[i] = c_target_oper_names[i];
871 }
872 TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
873 c_outputs, target_oper_names, nullptr, status);
874 }
875
TF_LoadLibrary(const char * library_filename,TF_Status * status)876 TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
877 TF_Library* lib_handle = new TF_Library;
878 status->status = tensorflow::LoadLibrary(
879 library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
880 &lib_handle->op_list.length);
881 if (!status->status.ok()) {
882 delete lib_handle;
883 return nullptr;
884 }
885 return lib_handle;
886 }
887
TF_GetOpList(TF_Library * lib_handle)888 TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
889
TF_DeleteLibraryHandle(TF_Library * lib_handle)890 void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
891 tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
892 delete lib_handle;
893 }
894
TF_GetAllOpList()895 TF_Buffer* TF_GetAllOpList() {
896 std::vector<tensorflow::OpDef> op_defs;
897 tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
898 tensorflow::OpList op_list;
899 for (const auto& op : op_defs) {
900 *(op_list.add_op()) = op;
901 }
902 TF_Buffer* ret = TF_NewBuffer();
903 TF_CHECK_OK(MessageToBuffer(op_list, ret));
904 return ret;
905 }
906
907 // --------------------------------------------------------------------------
908 // ListDevices & SessionListDevices API
909
TF_DeleteDeviceList(TF_DeviceList * s)910 void TF_DeleteDeviceList(TF_DeviceList* s) { delete s; }
911
TF_SessionListDevices(TF_Session * session,TF_Status * status)912 TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
913 TF_DeviceList* response = new TF_DeviceList;
914 status->status = session->session->ListDevices(&response->response);
915 return response;
916 }
917
TF_DeprecatedSessionListDevices(TF_DeprecatedSession * session,TF_Status * status)918 TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
919 TF_Status* status) {
920 TF_DeviceList* response = new TF_DeviceList;
921 status->status = session->session->ListDevices(&response->response);
922 return response;
923 }
924
TF_DeviceListCount(const TF_DeviceList * list)925 int TF_DeviceListCount(const TF_DeviceList* list) {
926 return list->response.size();
927 }
928
929 #define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
930 return_type method_name(const TF_DeviceList* list, const int index, \
931 TF_Status* status) { \
932 if (list == nullptr) { \
933 status->status = InvalidArgument("list is null!"); \
934 return err_val; \
935 } \
936 if (index < 0 || index >= list->response.size()) { \
937 status->status = InvalidArgument("index out of bounds"); \
938 return err_val; \
939 } \
940 status->status = Status::OK(); \
941 return list->response[index].accessor; \
942 }
943
944 TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
945 TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
946 nullptr);
947 TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
948
949 #undef TF_DEVICELIST_METHOD
950
951 } // end extern "C"
952
953 // --------------------------------------------------------------------------
954 // New Graph and Session API
955
956 // Helper functions -----------------------------------------------------------
957
958 namespace {
959
ToOperation(Node * node)960 TF_Operation* ToOperation(Node* node) {
961 return static_cast<TF_Operation*>(static_cast<void*>(node));
962 }
963
OutputName(const TF_Output & output)964 string OutputName(const TF_Output& output) {
965 return StrCat(output.oper->node.name(), ":", output.index);
966 }
967
GetAttrValue(TF_Operation * oper,const char * attr_name,TF_Status * status)968 const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
969 const char* attr_name,
970 TF_Status* status) {
971 const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
972 if (attr == nullptr) {
973 status->status = InvalidArgument("Operation '", oper->node.name(),
974 "' has no attr named '", attr_name, "'.");
975 }
976 return attr;
977 }
978
ToTensorId(const TF_Output & output)979 TensorId ToTensorId(const TF_Output& output) {
980 return TensorId(output.oper->node.name(), output.index);
981 }
982
983 #ifndef __ANDROID__
OutputsFromTFOutputs(TF_Output * tf_outputs,int n)984 std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
985 int n) {
986 std::vector<tensorflow::Output> outputs(n);
987 for (int i = 0; i < n; ++i) {
988 outputs[i] =
989 tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
990 }
991 return outputs;
992 }
993
TFOutputsFromOutputs(const std::vector<tensorflow::Output> & outputs,TF_Output * tf_outputs)994 void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
995 TF_Output* tf_outputs) {
996 for (int i = 0; i < outputs.size(); i++) {
997 tf_outputs[i].oper = ToOperation(outputs[i].node());
998 tf_outputs[i].index = outputs[i].index();
999 }
1000 }
1001 #endif // __ANDROID__
1002
1003 } // namespace
1004
1005 // Shape functions -----------------------------------------------------------
1006
TF_GraphSetTensorShape(TF_Graph * graph,TF_Output output,const int64_t * dims,const int num_dims,TF_Status * status)1007 void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
1008 const int64_t* dims, const int num_dims,
1009 TF_Status* status) {
1010 Node* node = &output.oper->node;
1011
1012 mutex_lock l(graph->mu);
1013 tensorflow::shape_inference::InferenceContext* ic =
1014 graph->refiner.GetContext(node);
1015 if (ic == nullptr) {
1016 status->status =
1017 InvalidArgument("Node ", node->name(), " was not found in the graph");
1018 return;
1019 }
1020 tensorflow::shape_inference::ShapeHandle new_shape =
1021 tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
1022 status->status = graph->refiner.SetShape(node, output.index, new_shape);
1023 }
1024
TF_GraphGetTensorNumDims(TF_Graph * graph,TF_Output output,TF_Status * status)1025 int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
1026 TF_Status* status) {
1027 Node* node = &output.oper->node;
1028
1029 mutex_lock l(graph->mu);
1030 tensorflow::shape_inference::InferenceContext* ic =
1031 graph->refiner.GetContext(node);
1032 if (ic == nullptr) {
1033 status->status =
1034 InvalidArgument("Node ", node->name(), " was not found in the graph");
1035 return -1;
1036 }
1037
1038 tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
1039
1040 // Unknown rank means the number of dimensions is -1.
1041 if (!ic->RankKnown(shape)) {
1042 return -1;
1043 }
1044
1045 return ic->Rank(shape);
1046 }
1047
TF_GraphGetTensorShape(TF_Graph * graph,TF_Output output,int64_t * dims,int num_dims,TF_Status * status)1048 void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
1049 int num_dims, TF_Status* status) {
1050 Node* node = &output.oper->node;
1051
1052 mutex_lock l(graph->mu);
1053 tensorflow::shape_inference::InferenceContext* ic =
1054 graph->refiner.GetContext(node);
1055 if (ic == nullptr) {
1056 status->status =
1057 InvalidArgument("Node ", node->name(), " was not found in the graph");
1058 return;
1059 }
1060
1061 tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
1062
1063 int rank = -1;
1064 if (ic->RankKnown(shape)) {
1065 rank = ic->Rank(shape);
1066 }
1067
1068 if (num_dims != rank) {
1069 status->status = InvalidArgument("Expected rank is ", num_dims,
1070 " but actual rank is ", rank);
1071 return;
1072 }
1073
1074 if (num_dims == 0) {
1075 // Output shape is a scalar.
1076 return;
1077 }
1078
1079 // Rank is greater than 0, so fill in the values, if known, and
1080 // -1 for unknown values.
1081 for (int i = 0; i < num_dims; ++i) {
1082 auto dim = ic->Dim(shape, i);
1083 tensorflow::int64 value = -1;
1084 if (ic->ValueKnown(dim)) {
1085 value = ic->Value(dim);
1086 }
1087 dims[i] = value;
1088 }
1089 }
1090
1091 // TF_OperationDescription functions ------------------------------------------
1092
1093 extern "C" {
1094
TF_NewOperationLocked(TF_Graph * graph,const char * op_type,const char * oper_name)1095 static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
1096 const char* op_type,
1097 const char* oper_name)
1098 EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
1099 return new TF_OperationDescription(graph, op_type, oper_name);
1100 }
1101
TF_NewOperation(TF_Graph * graph,const char * op_type,const char * oper_name)1102 TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
1103 const char* oper_name) {
1104 mutex_lock l(graph->mu);
1105 return TF_NewOperationLocked(graph, op_type, oper_name);
1106 }
1107
TF_SetDevice(TF_OperationDescription * desc,const char * device)1108 void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
1109 desc->node_builder.Device(device);
1110 }
1111
TF_AddInput(TF_OperationDescription * desc,TF_Output input)1112 void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
1113 desc->node_builder.Input(&input.oper->node, input.index);
1114 }
1115
TF_AddInputList(TF_OperationDescription * desc,const TF_Output * inputs,int num_inputs)1116 void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
1117 int num_inputs) {
1118 std::vector<NodeBuilder::NodeOut> input_list;
1119 input_list.reserve(num_inputs);
1120 for (int i = 0; i < num_inputs; ++i) {
1121 input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
1122 }
1123 desc->node_builder.Input(input_list);
1124 }
1125
TF_AddControlInput(TF_OperationDescription * desc,TF_Operation * input)1126 void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
1127 desc->node_builder.ControlInput(&input->node);
1128 }
1129
TF_ColocateWith(TF_OperationDescription * desc,TF_Operation * op)1130 void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
1131 desc->colocation_constraints.emplace(
1132 StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
1133 }
1134
TF_SetAttrString(TF_OperationDescription * desc,const char * attr_name,const void * value,size_t length)1135 void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
1136 const void* value, size_t length) {
1137 tensorflow::StringPiece s(static_cast<const char*>(value), length);
1138 desc->node_builder.Attr(attr_name, s);
1139 }
1140
TF_SetAttrStringList(TF_OperationDescription * desc,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)1141 void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
1142 const void* const* values, const size_t* lengths,
1143 int num_values) {
1144 if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1145 desc->colocation_constraints.clear();
1146 for (int i = 0; i < num_values; ++i) {
1147 desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
1148 lengths[i]);
1149 }
1150 } else {
1151 std::vector<tensorflow::StringPiece> v;
1152 v.reserve(num_values);
1153 for (int i = 0; i < num_values; ++i) {
1154 v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
1155 }
1156 desc->node_builder.Attr(attr_name, v);
1157 }
1158 }
1159
TF_SetAttrInt(TF_OperationDescription * desc,const char * attr_name,int64_t value)1160 void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
1161 int64_t value) {
1162 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1163 "64-bit int types should match in size");
1164 desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
1165 }
1166
TF_SetAttrIntList(TF_OperationDescription * desc,const char * attr_name,const int64_t * values,int num_values)1167 void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
1168 const int64_t* values, int num_values) {
1169 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1170 "64-bit int types should match in size");
1171 desc->node_builder.Attr(
1172 attr_name,
1173 ArraySlice<const tensorflow::int64>(
1174 reinterpret_cast<const tensorflow::int64*>(values), num_values));
1175 }
1176
TF_SetAttrFloat(TF_OperationDescription * desc,const char * attr_name,float value)1177 void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
1178 float value) {
1179 desc->node_builder.Attr(attr_name, value);
1180 }
1181
TF_SetAttrFloatList(TF_OperationDescription * desc,const char * attr_name,const float * values,int num_values)1182 void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
1183 const float* values, int num_values) {
1184 desc->node_builder.Attr(attr_name,
1185 ArraySlice<const float>(values, num_values));
1186 }
1187
TF_SetAttrBool(TF_OperationDescription * desc,const char * attr_name,unsigned char value)1188 void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
1189 unsigned char value) {
1190 desc->node_builder.Attr(attr_name, static_cast<bool>(value));
1191 }
1192
TF_SetAttrBoolList(TF_OperationDescription * desc,const char * attr_name,const unsigned char * values,int num_values)1193 void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
1194 const unsigned char* values, int num_values) {
1195 std::unique_ptr<bool[]> b(new bool[num_values]);
1196 for (int i = 0; i < num_values; ++i) {
1197 b[i] = values[i];
1198 }
1199 desc->node_builder.Attr(attr_name,
1200 ArraySlice<const bool>(b.get(), num_values));
1201 }
1202
TF_SetAttrType(TF_OperationDescription * desc,const char * attr_name,TF_DataType value)1203 void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
1204 TF_DataType value) {
1205 desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
1206 }
1207
TF_SetAttrTypeList(TF_OperationDescription * desc,const char * attr_name,const TF_DataType * values,int num_values)1208 void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
1209 const TF_DataType* values, int num_values) {
1210 desc->node_builder.Attr(
1211 attr_name, ArraySlice<const DataType>(
1212 reinterpret_cast<const DataType*>(values), num_values));
1213 }
1214
TF_SetAttrFuncName(TF_OperationDescription * desc,const char * attr_name,const char * value,size_t length)1215 void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
1216 const char* value, size_t length) {
1217 tensorflow::NameAttrList func_name;
1218 func_name.set_name(std::string(value, value + length));
1219 desc->node_builder.Attr(attr_name, func_name);
1220 }
1221
TF_SetAttrShape(TF_OperationDescription * desc,const char * attr_name,const int64_t * dims,int num_dims)1222 void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
1223 const int64_t* dims, int num_dims) {
1224 PartialTensorShape shape;
1225 if (num_dims >= 0) {
1226 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1227 "64-bit int types should match in size");
1228 shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
1229 reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
1230 }
1231 desc->node_builder.Attr(attr_name, shape);
1232 }
1233
TF_SetAttrShapeList(TF_OperationDescription * desc,const char * attr_name,const int64_t * const * dims,const int * num_dims,int num_shapes)1234 void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
1235 const int64_t* const* dims, const int* num_dims,
1236 int num_shapes) {
1237 std::vector<PartialTensorShape> shapes;
1238 shapes.reserve(num_shapes);
1239 for (int i = 0; i < num_shapes; ++i) {
1240 if (num_dims[i] < 0) {
1241 shapes.emplace_back();
1242 } else {
1243 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
1244 "64-bit int types should match in size");
1245 shapes.emplace_back(ArraySlice<tensorflow::int64>(
1246 reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
1247 }
1248 }
1249 desc->node_builder.Attr(attr_name, shapes);
1250 }
1251
TF_SetAttrTensorShapeProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)1252 void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
1253 const char* attr_name, const void* proto,
1254 size_t proto_len, TF_Status* status) {
1255 // shape.ParseFromArray takes an int as length, this function takes size_t,
1256 // make sure there is no information loss.
1257 if (proto_len > std::numeric_limits<int>::max()) {
1258 status->status = InvalidArgument(
1259 "proto_len (", proto_len,
1260 " bytes) is too large to be parsed by the protocol buffer library");
1261 return;
1262 }
1263 TensorShapeProto shape;
1264 if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
1265 desc->node_builder.Attr(attr_name, shape);
1266 status->status = Status::OK();
1267 } else {
1268 status->status = InvalidArgument("Unparseable TensorShapeProto");
1269 }
1270 }
1271
TF_SetAttrTensorShapeProtoList(TF_OperationDescription * desc,const char * attr_name,const void * const * protos,const size_t * proto_lens,int num_shapes,TF_Status * status)1272 void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
1273 const char* attr_name,
1274 const void* const* protos,
1275 const size_t* proto_lens, int num_shapes,
1276 TF_Status* status) {
1277 std::vector<TensorShapeProto> shapes;
1278 shapes.resize(num_shapes);
1279 for (int i = 0; i < num_shapes; ++i) {
1280 if (proto_lens[i] > std::numeric_limits<int>::max()) {
1281 status->status = InvalidArgument(
1282 "length of element ", i, " in the list (", proto_lens[i],
1283 " bytes) is too large to be parsed by the protocol buffer library");
1284 return;
1285 }
1286 if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) {
1287 status->status =
1288 InvalidArgument("Unparseable TensorShapeProto at index ", i);
1289 return;
1290 }
1291 }
1292 desc->node_builder.Attr(attr_name, shapes);
1293 status->status = Status::OK();
1294 }
1295
TF_SetAttrTensor(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * value,TF_Status * status)1296 void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
1297 TF_Tensor* value, TF_Status* status) {
1298 Tensor t;
1299 status->status = TF_TensorToTensor(value, &t);
1300 if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1301 }
1302
TF_SetAttrTensorList(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * const * values,int num_values,TF_Status * status)1303 void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
1304 TF_Tensor* const* values, int num_values,
1305 TF_Status* status) {
1306 status->status = Status::OK();
1307 std::vector<Tensor> t;
1308 t.reserve(num_values);
1309
1310 for (int i = 0; i < num_values && status->status.ok(); ++i) {
1311 Tensor v;
1312 status->status = TF_TensorToTensor(values[i], &v);
1313 t.emplace_back(v);
1314 }
1315
1316 if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1317 }
1318
TF_SetAttrValueProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)1319 void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
1320 const void* proto, size_t proto_len,
1321 TF_Status* status) {
1322 tensorflow::AttrValue attr_value;
1323 if (!attr_value.ParseFromArray(proto, proto_len)) {
1324 status->status = InvalidArgument("Unparseable AttrValue proto");
1325 return;
1326 }
1327
1328 if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1329 if (attr_value.value_case() != tensorflow::AttrValue::kList &&
1330 attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
1331 status->status =
1332 InvalidArgument("Expected \"list\" field for \"",
1333 tensorflow::kColocationAttrName, "\" attribute");
1334 return;
1335 }
1336 desc->colocation_constraints.clear();
1337 for (const string& location : attr_value.list().s()) {
1338 desc->colocation_constraints.insert(location);
1339 }
1340 } else {
1341 desc->node_builder.Attr(attr_name, attr_value);
1342 }
1343
1344 status->status = Status::OK();
1345 }
1346
TF_FinishOperationLocked(TF_OperationDescription * desc,TF_Status * status)1347 static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
1348 TF_Status* status)
1349 EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
1350 Node* ret = nullptr;
1351
1352 if (desc->graph->name_map.count(desc->node_builder.node_name())) {
1353 status->status = InvalidArgument("Duplicate node name in graph: '",
1354 desc->node_builder.node_name(), "'");
1355 } else {
1356 if (!desc->colocation_constraints.empty()) {
1357 desc->node_builder.Attr(
1358 tensorflow::kColocationAttrName,
1359 std::vector<string>(desc->colocation_constraints.begin(),
1360 desc->colocation_constraints.end()));
1361 }
1362 status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret);
1363
1364 if (status->status.ok()) {
1365 // Run shape inference function for newly added node.
1366 status->status = desc->graph->refiner.AddNode(ret);
1367 }
1368 if (status->status.ok()) {
1369 // Add the node to the name-to-node mapping.
1370 desc->graph->name_map[ret->name()] = ret;
1371 } else if (ret != nullptr) {
1372 desc->graph->graph.RemoveNode(ret);
1373 ret = nullptr;
1374 }
1375 }
1376
1377 delete desc;
1378
1379 return ToOperation(ret);
1380 }
1381
TF_FinishOperation(TF_OperationDescription * desc,TF_Status * status)1382 TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
1383 TF_Status* status) {
1384 mutex_lock l(desc->graph->mu);
1385 return TF_FinishOperationLocked(desc, status);
1386 }
1387
1388 // TF_Operation functions
1389 // ----------------------------------------------------------
1390
TF_OperationName(TF_Operation * oper)1391 const char* TF_OperationName(TF_Operation* oper) {
1392 return oper->node.name().c_str();
1393 }
1394
TF_OperationOpType(TF_Operation * oper)1395 const char* TF_OperationOpType(TF_Operation* oper) {
1396 return oper->node.type_string().c_str();
1397 }
1398
TF_OperationDevice(TF_Operation * oper)1399 const char* TF_OperationDevice(TF_Operation* oper) {
1400 return oper->node.requested_device().c_str();
1401 }
1402
TF_OperationNumOutputs(TF_Operation * oper)1403 int TF_OperationNumOutputs(TF_Operation* oper) {
1404 return oper->node.num_outputs();
1405 }
1406
TF_OperationOutputType(TF_Output oper_out)1407 TF_DataType TF_OperationOutputType(TF_Output oper_out) {
1408 return static_cast<TF_DataType>(
1409 oper_out.oper->node.output_type(oper_out.index));
1410 }
1411
TF_OperationOutputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1412 int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
1413 TF_Status* status) {
1414 NameRangeMap name_ranges;
1415 status->status =
1416 NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
1417 if (!status->status.ok()) return -1;
1418 auto iter = name_ranges.find(arg_name);
1419 if (iter == name_ranges.end()) {
1420 status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1421 return -1;
1422 }
1423 return iter->second.second - iter->second.first;
1424 }
1425
TF_OperationNumInputs(TF_Operation * oper)1426 int TF_OperationNumInputs(TF_Operation* oper) {
1427 return oper->node.num_inputs();
1428 }
1429
TF_OperationInputType(TF_Input oper_in)1430 TF_DataType TF_OperationInputType(TF_Input oper_in) {
1431 return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
1432 }
1433
TF_OperationInputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1434 int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
1435 TF_Status* status) {
1436 NameRangeMap name_ranges;
1437 status->status =
1438 NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
1439 if (!status->status.ok()) return -1;
1440 auto iter = name_ranges.find(arg_name);
1441 if (iter == name_ranges.end()) {
1442 status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1443 return -1;
1444 }
1445 return iter->second.second - iter->second.first;
1446 }
1447
TF_OperationInput(TF_Input oper_in)1448 TF_Output TF_OperationInput(TF_Input oper_in) {
1449 const tensorflow::Edge* edge;
1450 Status s = oper_in.oper->node.input_edge(oper_in.index, &edge);
1451 if (!s.ok()) {
1452 return {nullptr, -1};
1453 }
1454
1455 return {ToOperation(edge->src()), edge->src_output()};
1456 }
1457
TF_OperationOutputNumConsumers(TF_Output oper_out)1458 int TF_OperationOutputNumConsumers(TF_Output oper_out) {
1459 int count = 0;
1460 for (const auto* edge : oper_out.oper->node.out_edges()) {
1461 if (edge->src_output() == oper_out.index) {
1462 ++count;
1463 }
1464 }
1465 return count;
1466 }
1467
TF_OperationOutputConsumers(TF_Output oper_out,TF_Input * consumers,int max_consumers)1468 int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
1469 int max_consumers) {
1470 int count = 0;
1471 for (const auto* edge : oper_out.oper->node.out_edges()) {
1472 if (edge->src_output() == oper_out.index) {
1473 if (count < max_consumers) {
1474 consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
1475 }
1476 ++count;
1477 }
1478 }
1479 return count;
1480 }
1481
TF_OperationNumControlInputs(TF_Operation * oper)1482 int TF_OperationNumControlInputs(TF_Operation* oper) {
1483 int count = 0;
1484 for (const auto* edge : oper->node.in_edges()) {
1485 if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1486 ++count;
1487 }
1488 }
1489 return count;
1490 }
1491
TF_OperationGetControlInputs(TF_Operation * oper,TF_Operation ** control_inputs,int max_control_inputs)1492 int TF_OperationGetControlInputs(TF_Operation* oper,
1493 TF_Operation** control_inputs,
1494 int max_control_inputs) {
1495 int count = 0;
1496 for (const auto* edge : oper->node.in_edges()) {
1497 if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1498 if (count < max_control_inputs) {
1499 control_inputs[count] = ToOperation(edge->src());
1500 }
1501 ++count;
1502 }
1503 }
1504 return count;
1505 }
1506
TF_OperationNumControlOutputs(TF_Operation * oper)1507 int TF_OperationNumControlOutputs(TF_Operation* oper) {
1508 int count = 0;
1509 for (const auto* edge : oper->node.out_edges()) {
1510 if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1511 ++count;
1512 }
1513 }
1514 return count;
1515 }
1516
TF_OperationGetControlOutputs(TF_Operation * oper,TF_Operation ** control_outputs,int max_control_outputs)1517 int TF_OperationGetControlOutputs(TF_Operation* oper,
1518 TF_Operation** control_outputs,
1519 int max_control_outputs) {
1520 int count = 0;
1521 for (const auto* edge : oper->node.out_edges()) {
1522 if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1523 if (count < max_control_outputs) {
1524 control_outputs[count] = ToOperation(edge->dst());
1525 }
1526 ++count;
1527 }
1528 }
1529 return count;
1530 }
1531
TF_OperationGetAttrMetadata(TF_Operation * oper,const char * attr_name,TF_Status * status)1532 TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1533 const char* attr_name,
1534 TF_Status* status) {
1535 TF_AttrMetadata metadata;
1536 const auto* attr = GetAttrValue(oper, attr_name, status);
1537 if (!status->status.ok()) return metadata;
1538 switch (attr->value_case()) {
1539 #define SINGLE_CASE(kK, attr_type, size_expr) \
1540 case tensorflow::AttrValue::kK: \
1541 metadata.is_list = 0; \
1542 metadata.list_size = -1; \
1543 metadata.type = attr_type; \
1544 metadata.total_size = size_expr; \
1545 break;
1546
1547 SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
1548 SINGLE_CASE(kI, TF_ATTR_INT, -1);
1549 SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
1550 SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
1551 SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
1552 SINGLE_CASE(kShape, TF_ATTR_SHAPE,
1553 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
1554 SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
1555 #undef SINGLE_CASE
1556
1557 case tensorflow::AttrValue::kList:
1558 metadata.is_list = 1;
1559 metadata.list_size = 0;
1560 metadata.total_size = -1;
1561 #define LIST_CASE(field, attr_type, ...) \
1562 if (attr->list().field##_size() > 0) { \
1563 metadata.type = attr_type; \
1564 metadata.list_size = attr->list().field##_size(); \
1565 __VA_ARGS__; \
1566 break; \
1567 }
1568
1569 LIST_CASE(s, TF_ATTR_STRING, metadata.total_size = 0;
1570 for (int i = 0; i < attr->list().s_size();
1571 ++i) { metadata.total_size += attr->list().s(i).size(); });
1572 LIST_CASE(i, TF_ATTR_INT);
1573 LIST_CASE(f, TF_ATTR_FLOAT);
1574 LIST_CASE(b, TF_ATTR_BOOL);
1575 LIST_CASE(type, TF_ATTR_TYPE);
1576 LIST_CASE(shape, TF_ATTR_SHAPE, metadata.total_size = 0;
1577 for (int i = 0; i < attr->list().shape_size(); ++i) {
1578 const auto& s = attr->list().shape(i);
1579 metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
1580 });
1581 LIST_CASE(tensor, TF_ATTR_TENSOR);
1582 LIST_CASE(tensor, TF_ATTR_FUNC);
1583 #undef LIST_CASE
1584 // All lists empty, determine the type from the OpDef.
1585 if (metadata.list_size == 0) {
1586 for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
1587 const auto& a = oper->node.op_def().attr(i);
1588 if (a.name().compare(attr_name) != 0) continue;
1589 const string& typestr = a.type();
1590 if (typestr == "list(string)") {
1591 metadata.type = TF_ATTR_STRING;
1592 } else if (typestr == "list(int)") {
1593 metadata.type = TF_ATTR_INT;
1594 } else if (typestr == "list(float)") {
1595 metadata.type = TF_ATTR_FLOAT;
1596 } else if (typestr == "list(bool)") {
1597 metadata.type = TF_ATTR_BOOL;
1598 } else if (typestr == "list(type)") {
1599 metadata.type = TF_ATTR_TYPE;
1600 } else if (typestr == "list(shape)") {
1601 metadata.type = TF_ATTR_SHAPE;
1602 } else if (typestr == "list(tensor)") {
1603 metadata.type = TF_ATTR_TENSOR;
1604 } else if (typestr == "list(func)") {
1605 metadata.type = TF_ATTR_FUNC;
1606 } else {
1607 status->status = InvalidArgument(
1608 "Attribute '", attr_name,
1609 "' has an empty value of an unrecognized type '", typestr, "'");
1610 return metadata;
1611 }
1612 }
1613 }
1614 break;
1615
1616 case tensorflow::AttrValue::kPlaceholder:
1617 metadata.is_list = 0;
1618 metadata.list_size = -1;
1619 metadata.type = TF_ATTR_PLACEHOLDER;
1620 metadata.total_size = -1;
1621 break;
1622
1623 case tensorflow::AttrValue::kFunc:
1624 metadata.is_list = 0;
1625 metadata.list_size = -1;
1626 metadata.type = TF_ATTR_FUNC;
1627 metadata.total_size = -1;
1628 break;
1629
1630 case tensorflow::AttrValue::VALUE_NOT_SET:
1631 status->status =
1632 InvalidArgument("Attribute '", attr_name, "' has no value set");
1633 break;
1634 }
1635 return metadata;
1636 }
1637
TF_OperationGetAttrString(TF_Operation * oper,const char * attr_name,void * value,size_t max_length,TF_Status * status)1638 void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
1639 void* value, size_t max_length,
1640 TF_Status* status) {
1641 const auto* attr = GetAttrValue(oper, attr_name, status);
1642 if (!status->status.ok()) return;
1643 if (attr->value_case() != tensorflow::AttrValue::kS) {
1644 status->status =
1645 InvalidArgument("Attribute '", attr_name, "' is not a string");
1646 return;
1647 }
1648 if (max_length <= 0) {
1649 return;
1650 }
1651 const auto& s = attr->s();
1652 std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length));
1653 }
1654
TF_OperationGetAttrStringList(TF_Operation * oper,const char * attr_name,void ** values,size_t * lengths,int max_values,void * storage,size_t storage_size,TF_Status * status)1655 void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
1656 void** values, size_t* lengths,
1657 int max_values, void* storage,
1658 size_t storage_size, TF_Status* status) {
1659 const auto* attr = GetAttrValue(oper, attr_name, status);
1660 if (!status->status.ok()) return;
1661 if (attr->value_case() != tensorflow::AttrValue::kList) {
1662 status->status =
1663 InvalidArgument("Value for '", attr_name, "' is not a list");
1664 return;
1665 }
1666 const auto len = std::min(max_values, attr->list().s_size());
1667 char* p = static_cast<char*>(storage);
1668 for (int i = 0; i < len; ++i) {
1669 const string& s = attr->list().s(i);
1670 values[i] = p;
1671 lengths[i] = s.size();
1672 if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
1673 status->status = InvalidArgument(
1674 "Not enough storage to hold the requested list of strings");
1675 return;
1676 }
1677 memcpy(values[i], s.data(), s.size());
1678 p += s.size();
1679 }
1680 }
1681
1682 #define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \
1683 void func(TF_Operation* oper, const char* attr_name, c_type* value, \
1684 TF_Status* status) { \
1685 cpp_type v; \
1686 status->status = \
1687 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \
1688 *value = static_cast<c_type>(v); \
1689 } \
1690 void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
1691 int max_values, TF_Status* status) { \
1692 const auto* attr = GetAttrValue(oper, attr_name, status); \
1693 if (!status->status.ok()) return; \
1694 if (attr->value_case() != tensorflow::AttrValue::kList) { \
1695 status->status = \
1696 InvalidArgument("Value for '", attr_name, "' is not a list."); \
1697 return; \
1698 } \
1699 const auto len = std::min(max_values, attr->list().list_field##_size()); \
1700 for (int i = 0; i < len; ++i) { \
1701 values[i] = static_cast<c_type>(attr->list().list_field(i)); \
1702 } \
1703 }
1704 DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
1705 DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
1706 DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
1707 DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
1708 #undef DEFINE_GETATTR
1709
TF_OperationGetAttrShape(TF_Operation * oper,const char * attr_name,int64_t * value,int num_dims,TF_Status * status)1710 void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
1711 int64_t* value, int num_dims, TF_Status* status) {
1712 PartialTensorShape shape;
1713 status->status =
1714 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
1715 if (!status->status.ok()) return;
1716 auto len = std::min(shape.dims(), num_dims);
1717 for (int i = 0; i < len; ++i) {
1718 value[i] = shape.dim_size(i);
1719 }
1720 }
1721
TF_OperationGetAttrShapeList(TF_Operation * oper,const char * attr_name,int64_t ** values,int * num_dims,int max_values,int64_t * storage,int storage_size,TF_Status * status)1722 void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
1723 int64_t** values, int* num_dims,
1724 int max_values, int64_t* storage,
1725 int storage_size, TF_Status* status) {
1726 std::vector<PartialTensorShape> shapes;
1727 status->status =
1728 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
1729 if (!status->status.ok()) return;
1730 auto len = std::min(static_cast<int>(shapes.size()), max_values);
1731 int64_t* p = storage;
1732 int storage_left = storage_size;
1733 for (int i = 0; i < len; ++i) {
1734 // shapes[i].dims() == -1 for shapes with an unknown rank.
1735 int64_t n = shapes[i].dims();
1736 num_dims[i] = n;
1737 values[i] = p;
1738 if (n < 0) {
1739 continue;
1740 }
1741 if (storage_left < n) {
1742 status->status = InvalidArgument(
1743 "Not enough storage to hold the requested list of shapes");
1744 return;
1745 }
1746 storage_left -= n;
1747 for (int j = 0; j < n; ++j, ++p) {
1748 *p = shapes[i].dim_size(j);
1749 }
1750 }
1751 }
1752
TF_OperationGetAttrTensorShapeProto(TF_Operation * oper,const char * attr_name,TF_Buffer * value,TF_Status * status)1753 void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
1754 const char* attr_name,
1755 TF_Buffer* value, TF_Status* status) {
1756 const auto* attr = GetAttrValue(oper, attr_name, status);
1757 if (!status->status.ok()) return;
1758 if (attr->value_case() != tensorflow::AttrValue::kShape) {
1759 status->status =
1760 InvalidArgument("Value for '", attr_name, "' is not a shape.");
1761 return;
1762 }
1763 status->status = MessageToBuffer(attr->shape(), value);
1764 }
1765
TF_OperationGetAttrTensorShapeProtoList(TF_Operation * oper,const char * attr_name,TF_Buffer ** values,int max_values,TF_Status * status)1766 void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
1767 const char* attr_name,
1768 TF_Buffer** values, int max_values,
1769 TF_Status* status) {
1770 const auto* attr = GetAttrValue(oper, attr_name, status);
1771 if (!status->status.ok()) return;
1772 if (attr->value_case() != tensorflow::AttrValue::kList) {
1773 status->status =
1774 InvalidArgument("Value for '", attr_name, "' is not a list");
1775 return;
1776 }
1777 const auto len = std::min(max_values, attr->list().shape_size());
1778 for (int i = 0; i < len; ++i) {
1779 values[i] = TF_NewBuffer();
1780 status->status = MessageToBuffer(attr->list().shape(i), values[i]);
1781 if (!status->status.ok()) {
1782 // Delete everything allocated to far, the operation has failed.
1783 for (int j = 0; j <= i; ++j) {
1784 TF_DeleteBuffer(values[j]);
1785 }
1786 return;
1787 }
1788 }
1789 }
1790
TF_OperationGetAttrTensor(TF_Operation * oper,const char * attr_name,TF_Tensor ** value,TF_Status * status)1791 void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
1792 TF_Tensor** value, TF_Status* status) {
1793 *value = nullptr;
1794 Tensor t;
1795 status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
1796 if (!status->status.ok()) return;
1797 *value = TF_TensorFromTensor(t, status);
1798 }
1799
TF_OperationGetAttrTensorList(TF_Operation * oper,const char * attr_name,TF_Tensor ** values,int max_values,TF_Status * status)1800 void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
1801 TF_Tensor** values, int max_values,
1802 TF_Status* status) {
1803 std::vector<Tensor> ts;
1804 status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
1805 if (!status->status.ok()) return;
1806 const auto len = std::min(max_values, static_cast<int>(ts.size()));
1807 for (int i = 0; i < len; ++i) {
1808 values[i] = TF_TensorFromTensor(ts[i], status);
1809 }
1810 }
1811
TF_OperationGetAttrValueProto(TF_Operation * oper,const char * attr_name,TF_Buffer * output_attr_value,TF_Status * status)1812 void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
1813 TF_Buffer* output_attr_value,
1814 TF_Status* status) {
1815 const auto* attr = GetAttrValue(oper, attr_name, status);
1816 if (!status->status.ok()) return;
1817 status->status = MessageToBuffer(*attr, output_attr_value);
1818 }
1819
TF_OperationToNodeDef(TF_Operation * oper,TF_Buffer * output_node_def,TF_Status * status)1820 void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
1821 TF_Status* status) {
1822 status->status = MessageToBuffer(oper->node.def(), output_node_def);
1823 }
1824
1825 // TF_Graph functions ---------------------------------------------------------
1826
TF_Graph()1827 TF_Graph::TF_Graph()
1828 : graph(tensorflow::OpRegistry::Global()),
1829 refiner(graph.versions().producer(), graph.op_registry()),
1830 delete_requested(false),
1831 parent(nullptr),
1832 parent_inputs(nullptr) {}
1833
TF_NewGraph()1834 TF_Graph* TF_NewGraph() { return new TF_Graph; }
1835
TF_DeleteGraph(TF_Graph * g)1836 void TF_DeleteGraph(TF_Graph* g) {
1837 g->mu.lock();
1838 g->delete_requested = true;
1839 const bool del = g->sessions.empty();
1840 g->mu.unlock();
1841 if (del) delete g;
1842 }
1843
TF_GraphOperationByName(TF_Graph * graph,const char * oper_name)1844 TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
1845 mutex_lock l(graph->mu);
1846 auto iter = graph->name_map.find(oper_name);
1847 if (iter == graph->name_map.end()) {
1848 return nullptr;
1849 } else {
1850 return ToOperation(iter->second);
1851 }
1852 }
1853
TF_GraphNextOperation(TF_Graph * graph,size_t * pos)1854 TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
1855 if (*pos == 0) {
1856 // Advance past the first sentinel nodes in every graph (the source & sink).
1857 *pos += 2;
1858 } else {
1859 // Advance to the next node.
1860 *pos += 1;
1861 }
1862
1863 mutex_lock l(graph->mu);
1864 while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) {
1865 Node* node = graph->graph.FindNodeId(*pos);
1866 // FindNodeId() returns nullptr for nodes that have been deleted.
1867 // We aren't currently allowing nodes to be deleted, but it is safer
1868 // to still check.
1869 if (node != nullptr) return ToOperation(node);
1870 *pos += 1;
1871 }
1872
1873 // No more nodes.
1874 return nullptr;
1875 }
1876
TF_GraphToGraphDef(TF_Graph * graph,TF_Buffer * output_graph_def,TF_Status * status)1877 void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
1878 TF_Status* status) {
1879 GraphDef def;
1880 {
1881 mutex_lock l(graph->mu);
1882 graph->graph.ToGraphDef(&def);
1883 }
1884 status->status = MessageToBuffer(def, output_graph_def);
1885 }
1886
TF_GraphGetOpDef(TF_Graph * graph,const char * op_name,TF_Buffer * output_op_def,TF_Status * status)1887 void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
1888 TF_Buffer* output_op_def, TF_Status* status) {
1889 const OpDef* op_def;
1890 {
1891 mutex_lock l(graph->mu);
1892 status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
1893 if (!status->status.ok()) return;
1894 }
1895 status->status = MessageToBuffer(*op_def, output_op_def);
1896 }
1897
TF_GraphVersions(TF_Graph * graph,TF_Buffer * output_version_def,TF_Status * status)1898 void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
1899 TF_Status* status) {
1900 VersionDef versions;
1901 {
1902 mutex_lock l(graph->mu);
1903 versions = graph->graph.versions();
1904 }
1905 status->status = MessageToBuffer(versions, output_version_def);
1906 }
1907
TF_NewImportGraphDefOptions()1908 TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
1909 return new TF_ImportGraphDefOptions;
1910 }
TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions * opts)1911 void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
1912 delete opts;
1913 }
TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions * opts,const char * prefix)1914 void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
1915 const char* prefix) {
1916 opts->opts.prefix = prefix;
1917 }
1918
TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions * opts,unsigned char uniquify_names)1919 void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
1920 unsigned char uniquify_names) {
1921 opts->opts.uniquify_names = uniquify_names;
1922 }
1923
TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions * opts,unsigned char uniquify_prefix)1924 void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
1925 unsigned char uniquify_prefix) {
1926 opts->opts.uniquify_prefix = uniquify_prefix;
1927 }
1928
TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions * opts,const char * src_name,int src_index,TF_Output dst)1929 void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
1930 const char* src_name,
1931 int src_index, TF_Output dst) {
1932 opts->tensor_id_data.push_back(src_name);
1933 const string& src_name_str = opts->tensor_id_data.back();
1934 // We don't need to store dst's name in tensor_id_data, since `dst` must
1935 // outlive the ImportGraphDef call.
1936 opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
1937 }
1938
TF_ImportGraphDefOptionsRemapControlDependency(TF_ImportGraphDefOptions * opts,const char * src_name,TF_Operation * dst)1939 void TF_ImportGraphDefOptionsRemapControlDependency(
1940 TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
1941 opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
1942 TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
1943 }
1944
TF_ImportGraphDefOptionsAddControlDependency(TF_ImportGraphDefOptions * opts,TF_Operation * oper)1945 extern void TF_ImportGraphDefOptionsAddControlDependency(
1946 TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
1947 opts->opts.control_dependencies.push_back(oper->node.name());
1948 }
1949
TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions * opts,const char * oper_name,int index)1950 void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
1951 const char* oper_name, int index) {
1952 opts->tensor_id_data.push_back(oper_name);
1953 const string& oper_name_str = opts->tensor_id_data.back();
1954 opts->opts.return_tensors.emplace_back(oper_name_str, index);
1955 }
1956
TF_ImportGraphDefOptionsNumReturnOutputs(const TF_ImportGraphDefOptions * opts)1957 int TF_ImportGraphDefOptionsNumReturnOutputs(
1958 const TF_ImportGraphDefOptions* opts) {
1959 return opts->opts.return_tensors.size();
1960 }
1961
TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions * opts,const char * oper_name)1962 void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
1963 const char* oper_name) {
1964 opts->opts.return_nodes.push_back(oper_name);
1965 }
1966
TF_ImportGraphDefOptionsNumReturnOperations(const TF_ImportGraphDefOptions * opts)1967 int TF_ImportGraphDefOptionsNumReturnOperations(
1968 const TF_ImportGraphDefOptions* opts) {
1969 return opts->opts.return_nodes.size();
1970 }
1971
TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults * results,int * num_outputs,TF_Output ** outputs)1972 void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
1973 int* num_outputs,
1974 TF_Output** outputs) {
1975 *num_outputs = results->return_tensors.size();
1976 *outputs = results->return_tensors.data();
1977 }
1978
TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults * results,int * num_opers,TF_Operation *** opers)1979 void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
1980 int* num_opers,
1981 TF_Operation*** opers) {
1982 *num_opers = results->return_nodes.size();
1983 *opers = results->return_nodes.data();
1984 }
1985
TF_ImportGraphDefResultsMissingUnusedInputMappings(TF_ImportGraphDefResults * results,int * num_missing_unused_input_mappings,const char *** src_names,int ** src_indexes)1986 void TF_ImportGraphDefResultsMissingUnusedInputMappings(
1987 TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
1988 const char*** src_names, int** src_indexes) {
1989 *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
1990 *src_names = results->missing_unused_key_names.data();
1991 *src_indexes = results->missing_unused_key_indexes.data();
1992 }
1993
TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults * results)1994 void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
1995 delete results;
1996 }
1997
GraphImportGraphDefLocked(TF_Graph * graph,const GraphDef & def,const TF_ImportGraphDefOptions * opts,TF_ImportGraphDefResults * tf_results,TF_Status * status)1998 static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
1999 const TF_ImportGraphDefOptions* opts,
2000 TF_ImportGraphDefResults* tf_results,
2001 TF_Status* status)
2002 EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
2003 const int last_node_id = graph->graph.num_node_ids();
2004 tensorflow::ImportGraphDefResults results;
2005 status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
2006 &graph->refiner, &results);
2007 if (!status->status.ok()) return;
2008
2009 // Add new nodes to name_map
2010 for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
2011 auto* node = graph->graph.FindNodeId(i);
2012 if (node != nullptr) graph->name_map[node->name()] = node;
2013 }
2014
2015 // Populate return_tensors
2016 DCHECK(tf_results->return_tensors.empty());
2017 tf_results->return_tensors.resize(results.return_tensors.size());
2018 for (int i = 0; i < results.return_tensors.size(); ++i) {
2019 tf_results->return_tensors[i].oper =
2020 ToOperation(results.return_tensors[i].first);
2021 tf_results->return_tensors[i].index = results.return_tensors[i].second;
2022 }
2023
2024 // Populate return_nodes
2025 DCHECK(tf_results->return_nodes.empty());
2026 tf_results->return_nodes.resize(results.return_nodes.size());
2027 for (int i = 0; i < results.return_nodes.size(); ++i) {
2028 tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
2029 }
2030
2031 // Populate missing unused map keys
2032 DCHECK(tf_results->missing_unused_key_names.empty());
2033 DCHECK(tf_results->missing_unused_key_indexes.empty());
2034 DCHECK(tf_results->missing_unused_key_names_data.empty());
2035
2036 size_t size = results.missing_unused_input_map_keys.size();
2037 tf_results->missing_unused_key_names.resize(size);
2038 tf_results->missing_unused_key_indexes.resize(size);
2039
2040 for (int i = 0; i < size; ++i) {
2041 TensorId id = results.missing_unused_input_map_keys[i];
2042 tf_results->missing_unused_key_names_data.push_back(id.first.ToString());
2043 tf_results->missing_unused_key_names[i] =
2044 tf_results->missing_unused_key_names_data.back().c_str();
2045 tf_results->missing_unused_key_indexes[i] = id.second;
2046 }
2047 }
2048
TF_GraphImportGraphDefWithResults(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Status * status)2049 TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
2050 TF_Graph* graph, const TF_Buffer* graph_def,
2051 const TF_ImportGraphDefOptions* options, TF_Status* status) {
2052 GraphDef def;
2053 if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
2054 status->status = InvalidArgument("Invalid GraphDef");
2055 return nullptr;
2056 }
2057 auto results = new TF_ImportGraphDefResults();
2058 mutex_lock l(graph->mu);
2059 GraphImportGraphDefLocked(graph, def, options, results, status);
2060 if (!status->status.ok()) {
2061 delete results;
2062 return nullptr;
2063 }
2064 return results;
2065 }
2066
TF_GraphImportGraphDefWithReturnOutputs(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Output * return_outputs,int num_return_outputs,TF_Status * status)2067 void TF_GraphImportGraphDefWithReturnOutputs(
2068 TF_Graph* graph, const TF_Buffer* graph_def,
2069 const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
2070 int num_return_outputs, TF_Status* status) {
2071 if (num_return_outputs != options->opts.return_tensors.size()) {
2072 status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
2073 options->opts.return_tensors.size(),
2074 ", got ", num_return_outputs);
2075 return;
2076 }
2077 if (num_return_outputs > 0 && return_outputs == nullptr) {
2078 status->status = InvalidArgument(
2079 "'return_outputs' must be preallocated to length ", num_return_outputs);
2080 return;
2081 }
2082 GraphDef def;
2083 if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
2084 status->status = InvalidArgument("Invalid GraphDef");
2085 return;
2086 }
2087 TF_ImportGraphDefResults results;
2088 mutex_lock l(graph->mu);
2089 GraphImportGraphDefLocked(graph, def, options, &results, status);
2090 DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
2091 memcpy(return_outputs, results.return_tensors.data(),
2092 num_return_outputs * sizeof(TF_Output));
2093 }
2094
TF_GraphImportGraphDef(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Status * status)2095 void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
2096 const TF_ImportGraphDefOptions* options,
2097 TF_Status* status) {
2098 TF_ImportGraphDefResults* results =
2099 TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
2100 TF_DeleteImportGraphDefResults(results);
2101 }
2102
2103 // While loop functions -------------------------------------------------------
2104
2105 namespace {
2106
2107 #ifndef __ANDROID__
2108
2109 // Creates a placeholder representing an input to the cond or body graph.
2110 // TODO(skyewm): remove these from final graph
CreateInput(const TF_Output & parent_input,TF_Graph * g,const char * name,TF_Output * input,TF_Status * status)2111 bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
2112 TF_Output* input, TF_Status* status) {
2113 TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
2114 TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
2115 // TODO(skyewm): set placeholder shape
2116 TF_Operation* oper = TF_FinishOperation(desc, status);
2117 if (!status->status.ok()) return false;
2118 *input = {oper, 0};
2119 return true;
2120 }
2121
2122 // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
2123 // `src_inputs[i]` will have that input replaced with `dst_inputs[i]`. `prefix`
2124 // will be prepended to copied node names. `control_deps` are nodes in
2125 // `dst_graph` that the copied `src_graph` nodes will have control dependencies
2126 // on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
2127 // in `dst_graph` will be returned. `return_nodes` must be non-null.
CopyGraph(Graph * src_graph,Graph * dst_graph,tensorflow::ShapeRefiner * dst_refiner,const TF_Output * src_inputs,const std::vector<tensorflow::Output> & dst_inputs,const string & prefix,const std::vector<tensorflow::Operation> & control_deps,const TF_Output * nodes_to_return,int nreturn_nodes,std::vector<tensorflow::Output> * return_nodes)2128 Status CopyGraph(Graph* src_graph, Graph* dst_graph,
2129 tensorflow::ShapeRefiner* dst_refiner,
2130 const TF_Output* src_inputs,
2131 const std::vector<tensorflow::Output>& dst_inputs,
2132 const string& prefix,
2133 const std::vector<tensorflow::Operation>& control_deps,
2134 const TF_Output* nodes_to_return, int nreturn_nodes,
2135 std::vector<tensorflow::Output>* return_nodes) {
2136 DCHECK(return_nodes != nullptr);
2137 GraphDef gdef;
2138 src_graph->ToGraphDef(&gdef);
2139
2140 tensorflow::ImportGraphDefOptions opts;
2141 opts.prefix = prefix;
2142
2143 for (int i = 0; i < dst_inputs.size(); ++i) {
2144 opts.input_map[ToTensorId(src_inputs[i])] =
2145 TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
2146 }
2147 opts.skip_mapped_nodes = true;
2148
2149 for (const tensorflow::Operation& op : control_deps) {
2150 opts.control_dependencies.push_back(op.node()->name());
2151 }
2152
2153 for (int i = 0; i < nreturn_nodes; ++i) {
2154 opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
2155 }
2156
2157 // TODO(skyewm): change to OutputTensor
2158 tensorflow::ImportGraphDefResults results;
2159 TF_RETURN_IF_ERROR(
2160 ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
2161
2162 for (const auto& pair : results.return_tensors) {
2163 return_nodes->emplace_back(pair.first, pair.second);
2164 }
2165 return Status::OK();
2166 }
2167
ValidateConstWhileParams(const TF_WhileParams & params,TF_Status * s)2168 bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
2169 if (params.cond_graph == nullptr || params.body_graph == nullptr ||
2170 params.cond_graph->parent == nullptr ||
2171 params.cond_graph->parent != params.body_graph->parent ||
2172 params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
2173 params.ninputs <= 0 || params.cond_inputs == nullptr ||
2174 params.body_inputs == nullptr || params.body_outputs == nullptr) {
2175 s->status = InvalidArgument(
2176 "TF_WhileParams must be created by successful TF_NewWhile() call");
2177 return false;
2178 }
2179 return true;
2180 }
2181
ValidateInputWhileParams(const TF_WhileParams & params,TF_Status * s)2182 bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
2183 if (params.cond_output.oper == nullptr) {
2184 s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
2185 return false;
2186 }
2187 for (int i = 0; i < params.ninputs; ++i) {
2188 if (params.body_outputs[i].oper == nullptr) {
2189 s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
2190 "field isn't set");
2191 return false;
2192 }
2193 }
2194 if (params.name == nullptr) {
2195 s->status = InvalidArgument("TF_WhileParams `name` field is null");
2196 return false;
2197 }
2198 return true;
2199 }
2200
2201 #endif // __ANDROID__
2202
FreeWhileResources(const TF_WhileParams * params)2203 void FreeWhileResources(const TF_WhileParams* params) {
2204 TF_DeleteGraph(params->cond_graph);
2205 TF_DeleteGraph(params->body_graph);
2206 delete[] params->cond_inputs;
2207 delete[] params->body_inputs;
2208 delete[] params->body_outputs;
2209 }
2210
EmptyWhileParams()2211 TF_WhileParams EmptyWhileParams() {
2212 return {0, nullptr, nullptr, {nullptr, 0},
2213 nullptr, nullptr, nullptr, nullptr};
2214 }
2215
2216 } // namespace
2217
TF_NewWhile(TF_Graph * g,TF_Output * inputs,int ninputs,TF_Status * status)2218 TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
2219 TF_Status* status) {
2220 #ifdef __ANDROID__
2221 status->status = tensorflow::errors::Unimplemented(
2222 "Creating while loops is not supported in Android. File a bug at "
2223 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2224 "important to you");
2225 return EmptyWhileParams();
2226 #else
2227 if (ninputs == 0) {
2228 status->status =
2229 InvalidArgument("TF_NewWhile() must be passed at least one input");
2230 return EmptyWhileParams();
2231 }
2232
2233 TF_Graph* cond_graph = TF_NewGraph();
2234 TF_Graph* body_graph = TF_NewGraph();
2235 cond_graph->parent = g;
2236 cond_graph->parent_inputs = inputs;
2237 body_graph->parent = g;
2238 body_graph->parent_inputs = inputs;
2239
2240 TF_Output* cond_inputs = new TF_Output[ninputs];
2241 TF_Output cond_output = {nullptr, -1};
2242 TF_Output* body_inputs = new TF_Output[ninputs];
2243 TF_Output* body_outputs = new TF_Output[ninputs];
2244 for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
2245 const char* name = nullptr;
2246
2247 for (int i = 0; i < ninputs; ++i) {
2248 // TODO(skyewm): prefix names with underscore (requires some plumbing)
2249 if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
2250 &cond_inputs[i], status)) {
2251 break;
2252 }
2253 if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
2254 &body_inputs[i], status)) {
2255 break;
2256 }
2257 }
2258
2259 TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output,
2260 body_graph, body_inputs, body_outputs, name};
2261
2262 if (!status->status.ok()) {
2263 FreeWhileResources(¶ms);
2264 return EmptyWhileParams();
2265 }
2266 return params;
2267 #endif // __ANDROID__
2268 }
2269
2270 #ifndef __ANDROID__
2271 namespace {
2272
2273 // TODO(skyewm): make nodes in while loop unfetchable like in Python version
// Builds the while loop described by `params` into the parent graph shared by
// the cond and body graphs. Both subgraphs are copied into the parent via
// CopyGraph() from inside the BuildWhileLoop() callbacks. Loop outputs are
// written to `outputs`; errors go to `status`. The parent graph's mutex is
// held for the entire construction.
// NOTE(review): `outputs` is assumed to have room for `params->ninputs`
// entries (see the DCHECK_LE below); this is not checked here.
void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
                          TF_Output* outputs) {
  // Reject params whose cond output, body outputs or name were never set.
  if (!ValidateInputWhileParams(*params, status)) return;

  TF_Graph* parent = params->cond_graph->parent;
  TF_Output* parent_inputs = params->cond_graph->parent_inputs;
  int num_loop_vars = params->ninputs;

  mutex_lock l(parent->mu);

  // 'cond_fn' copies the cond graph into the parent graph.
  // Copied nodes are named under the scope's name and depend on the scope's
  // control dependencies. The copy of `params->cond_output` is returned.
  tensorflow::ops::CondGraphBuilderFn cond_fn =
      [params, parent](const tensorflow::Scope& scope,
                       const std::vector<tensorflow::Output>& inputs,
                       tensorflow::Output* output) {
        DCHECK_EQ(scope.graph(), &parent->graph);
        std::vector<tensorflow::Output> cond_output;
        TF_RETURN_IF_ERROR(CopyGraph(
            &params->cond_graph->graph, &parent->graph, &parent->refiner,
            params->cond_inputs, inputs, scope.impl()->name(),
            scope.impl()->control_deps(), &params->cond_output,
            /* nreturn_nodes */ 1, &cond_output));
        *output = cond_output[0];
        return Status::OK();
      };

  // 'body_fn' copies the body graph into the parent graph.
  // The copies of all `num_loop_vars` body outputs are appended to `outputs`.
  tensorflow::ops::BodyGraphBuilderFn body_fn =
      [params, parent, num_loop_vars](
          const tensorflow::Scope& scope,
          const std::vector<tensorflow::Output>& inputs,
          std::vector<tensorflow::Output>* outputs) {
        DCHECK_EQ(scope.graph(), &parent->graph);
        TF_RETURN_IF_ERROR(
            CopyGraph(&params->body_graph->graph, &parent->graph,
                      &parent->refiner, params->body_inputs, inputs,
                      scope.impl()->name(), scope.impl()->control_deps(),
                      params->body_outputs, num_loop_vars, outputs));
        return Status::OK();
      };

  // Create the while loop using an internal scope.
  // The scope reports errors into `status->status` and shares the parent
  // graph's shape refiner.
  tensorflow::Scope scope =
      NewInternalScope(&parent->graph, &status->status, &parent->refiner)
          .NewSubScope(params->name);

  // Ids from here on belong to nodes created by BuildWhileLoop().
  const int first_new_node_id = parent->graph.num_node_ids();

  tensorflow::OutputList loop_outputs;
  status->status = tensorflow::ops::BuildWhileLoop(
      scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
      body_fn, params->name, &loop_outputs);

  // Update name_map with newly-created ops.
  // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
  // a bad status. Once we fix this, we may want to return early instead of
  // executing the following code.
  for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
    Node* new_node = parent->graph.FindNodeId(i);
    if (new_node == nullptr) continue;
    parent->name_map[new_node->name()] = new_node;
  }

  // Populate 'outputs'.
  // Only as many entries as BuildWhileLoop() produced are written.
  DCHECK_LE(loop_outputs.size(), num_loop_vars);
  for (int i = 0; i < loop_outputs.size(); ++i) {
    outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
  }
}
2343
2344 } // namespace
2345 #endif // __ANDROID__
2346
TF_FinishWhile(const TF_WhileParams * params,TF_Status * status,TF_Output * outputs)2347 void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
2348 TF_Output* outputs) {
2349 #ifdef __ANDROID__
2350 status->status = tensorflow::errors::Unimplemented(
2351 "Creating while loops is not supported in Android. File a bug at "
2352 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2353 "important to you");
2354 #else
2355 // If it appears the caller created or modified `params`, don't free resources
2356 if (!ValidateConstWhileParams(*params, status)) return;
2357 TF_FinishWhileHelper(params, status, outputs);
2358 FreeWhileResources(params);
2359 #endif // __ANDROID__
2360 }
2361
TF_AbortWhile(const TF_WhileParams * params)2362 void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
2363
TF_AddGradients(TF_Graph * g,TF_Output * y,int ny,TF_Output * x,int nx,TF_Output * dx,TF_Status * status,TF_Output * dy)2364 void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
2365 TF_Output* dx, TF_Status* status, TF_Output* dy) {
2366 #ifdef __ANDROID__
2367 status->status = tensorflow::errors::Unimplemented(
2368 "Adding gradients is not supported in Android. File a bug at "
2369 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2370 "important to you");
2371 #else
2372 std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
2373 std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
2374 std::vector<tensorflow::Output> dy_arg;
2375
2376 {
2377 // We need to hold on to the lock while we have a scope that uses TF_Graph.
2378 mutex_lock graph_lock(g->mu);
2379
2380 const int first_new_node_id = g->graph.num_node_ids();
2381
2382 tensorflow::Scope scope =
2383 NewInternalScope(&g->graph, &status->status, &g->refiner)
2384 .NewSubScope("gradients");
2385
2386 if (dx != nullptr) {
2387 std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
2388 status->status =
2389 AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
2390 } else {
2391 status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
2392 }
2393
2394 // Update g->name_map with the name_map from the scope, which will contain
2395 // the new gradient ops.
2396 for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
2397 Node* n = g->graph.FindNodeId(i);
2398 if (n == nullptr) continue;
2399 g->name_map[n->name()] = n;
2400 }
2401 }
2402
2403 // Unpack the results from grad_outputs_arg.
2404 TFOutputsFromOutputs(dy_arg, dy);
2405 #endif // __ANDROID__
2406 }
2407
2408 // TF_Session functions ----------------------------------------------
2409
TF_Session(tensorflow::Session * s,TF_Graph * g)2410 TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
2411 : session(s), graph(g), last_num_graph_nodes(0), device_mgr(nullptr) {
2412 if (s->LocalDeviceManager(&device_mgr).ok()) {
2413 devices = device_mgr->ListDevices();
2414 }
2415 }
2416
TF_NewSession(TF_Graph * graph,const TF_SessionOptions * opt,TF_Status * status)2417 TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
2418 TF_Status* status) {
2419 Session* session;
2420 status->status = NewSession(opt->options, &session);
2421 if (status->status.ok()) {
2422 TF_Session* new_session = new TF_Session(session, graph);
2423 if (graph != nullptr) {
2424 mutex_lock l(graph->mu);
2425 graph->sessions[new_session] = Status::OK();
2426 }
2427 return new_session;
2428 } else {
2429 DCHECK_EQ(nullptr, session);
2430 return nullptr;
2431 }
2432 }
2433
TF_LoadSessionFromSavedModel(const TF_SessionOptions * session_options,const TF_Buffer * run_options,const char * export_dir,const char * const * tags,int tags_len,TF_Graph * graph,TF_Buffer * meta_graph_def,TF_Status * status)2434 TF_Session* TF_LoadSessionFromSavedModel(
2435 const TF_SessionOptions* session_options, const TF_Buffer* run_options,
2436 const char* export_dir, const char* const* tags, int tags_len,
2437 TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
2438 // TODO(ashankar): Remove the __ANDROID__ guard. This will require ensuring that
2439 // the tensorflow/cc/saved_model:loader build target is Android friendly.
2440 #ifdef __ANDROID__
2441 status->status = tensorflow::errors::Unimplemented(
2442 "Loading a SavedModel is not supported in Android. File a bug at "
2443 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2444 "important to you");
2445 return nullptr;
2446 #else
2447 mutex_lock l(graph->mu);
2448 if (!graph->name_map.empty()) {
2449 status->status = InvalidArgument("Graph is non-empty.");
2450 return nullptr;
2451 }
2452
2453 RunOptions run_options_proto;
2454 if (run_options != nullptr && !run_options_proto.ParseFromArray(
2455 run_options->data, run_options->length)) {
2456 status->status = InvalidArgument("Unparseable RunOptions proto");
2457 return nullptr;
2458 }
2459
2460 std::unordered_set<string> tag_set;
2461 for (int i = 0; i < tags_len; i++) {
2462 tag_set.insert(string(tags[i]));
2463 }
2464
2465 tensorflow::SavedModelBundle bundle;
2466 status->status =
2467 tensorflow::LoadSavedModel(session_options->options, run_options_proto,
2468 export_dir, tag_set, &bundle);
2469 if (!status->status.ok()) return nullptr;
2470
2471 // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
2472 // extends using GraphDefs. The Graph instance is different, but equivalent
2473 // to the one used to create the session.
2474 //
2475 // TODO(jhseu): When Session is modified to take Graphs instead of
2476 // GraphDefs, return the Graph generated in LoadSavedModel().
2477 TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
2478 TF_ImportGraphDefResults results;
2479 GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
2480 import_opts, &results, status);
2481 TF_DeleteImportGraphDefOptions(import_opts);
2482 if (TF_GetCode(status) != TF_OK) return nullptr;
2483
2484 if (meta_graph_def != nullptr) {
2485 status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
2486 if (!status->status.ok()) return nullptr;
2487 }
2488
2489 TF_Session* session = new TF_Session(bundle.session.release(), graph);
2490
2491 graph->sessions[session] = Status::OK();
2492 session->last_num_graph_nodes = graph->graph.num_node_ids();
2493 return session;
2494 #endif // __ANDROID__
2495 }
2496
TF_CloseSession(TF_Session * s,TF_Status * status)2497 void TF_CloseSession(TF_Session* s, TF_Status* status) {
2498 status->status = s->session->Close();
2499 }
2500
TF_DeleteSession(TF_Session * s,TF_Status * status)2501 void TF_DeleteSession(TF_Session* s, TF_Status* status) {
2502 status->status = Status::OK();
2503 TF_Graph* const graph = s->graph;
2504 if (graph != nullptr) {
2505 graph->mu.lock();
2506 graph->sessions.erase(s);
2507 const bool del = graph->delete_requested && graph->sessions.empty();
2508 graph->mu.unlock();
2509 if (del) delete graph;
2510 }
2511 delete s->session;
2512 delete s;
2513 }
2514
2515 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2516 // directly, instead of requiring us to serialize to a GraphDef and
2517 // call Session::Extend().
ExtendSessionGraphHelper(TF_Session * session,TF_Status * status)2518 static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
2519 if (session->graph != nullptr) {
2520 mutex_lock session_lock(session->mu);
2521 session->graph->mu.lock();
2522 const Graph& graph = session->graph->graph;
2523
2524 status->status = session->graph->sessions[session];
2525 if (!status->status.ok()) {
2526 session->graph->mu.unlock();
2527 return false;
2528 }
2529
2530 const auto num_nodes = graph.num_node_ids();
2531 if (session->last_num_graph_nodes < num_nodes) {
2532 status->status = tensorflow::ValidateNoCycles(session->graph->graph);
2533 if (!status->status.ok()) {
2534 session->graph->mu.unlock();
2535 return false;
2536 }
2537
2538 GraphDef graph_def;
2539 *graph_def.mutable_versions() = graph.versions();
2540 // Fill graph_def with nodes with ids in the range
2541 // [session->last_num_graph_nodes, num_nodes), that is the nodes
2542 // added since the last TF_SessionRun() call.
2543 for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
2544 Node* const node = graph.FindNodeId(id);
2545 if (node != nullptr && node->IsOp()) {
2546 NodeDef* const node_def = graph_def.add_node();
2547 *node_def = node->def();
2548 }
2549 }
2550 *graph_def.mutable_library() = graph.flib_def().ToProto();
2551 session->graph->mu.unlock();
2552 status->status = session->session->Extend(graph_def);
2553 if (!status->status.ok()) {
2554 // Contract is we always delete input_values[i].
2555 return false;
2556 }
2557 // Note: session->session is not modified if Extend() fails, so
2558 // we only set last_num_graph_nodes if it succeeds.
2559 session->last_num_graph_nodes = num_nodes;
2560 } else {
2561 session->graph->mu.unlock();
2562 }
2563 }
2564 return true;
2565 }
2566
TF_SessionRun(TF_Session * session,const TF_Buffer * run_options,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Buffer * run_metadata,TF_Status * status)2567 void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
2568 const TF_Output* inputs, TF_Tensor* const* input_values,
2569 int ninputs, const TF_Output* outputs,
2570 TF_Tensor** output_values, int noutputs,
2571 const TF_Operation* const* target_opers, int ntargets,
2572 TF_Buffer* run_metadata, TF_Status* status) {
2573 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2574 // directly, instead of requiring us to serialize to a GraphDef and
2575 // call Session::Extend().
2576 if (!ExtendSessionGraphHelper(session, status)) {
2577 return;
2578 }
2579
2580 TF_Run_Setup(noutputs, output_values, status);
2581
2582 // Convert from TF_Output and TF_Tensor to a string and Tensor.
2583 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2584 if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2585 for (int i = 0; i < ninputs; ++i) {
2586 input_pairs[i].first = OutputName(inputs[i]);
2587 }
2588
2589 // Convert from TF_Output to string names.
2590 std::vector<string> output_names(noutputs);
2591 for (int i = 0; i < noutputs; ++i) {
2592 output_names[i] = OutputName(outputs[i]);
2593 }
2594
2595 // Convert from TF_Operation* to string names.
2596 std::vector<string> target_names(ntargets);
2597 for (int i = 0; i < ntargets; ++i) {
2598 target_names[i] = target_opers[i]->node.name();
2599 }
2600
2601 // Actually run.
2602 TF_Run_Helper(session->session, nullptr, run_options, input_pairs,
2603 output_names, output_values, target_names, run_metadata,
2604 status);
2605 }
2606
TF_SessionPRunSetup(TF_Session * session,const TF_Output * inputs,int ninputs,const TF_Output * outputs,int noutputs,const TF_Operation * const * target_opers,int ntargets,const char ** handle,TF_Status * status)2607 void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
2608 int ninputs, const TF_Output* outputs, int noutputs,
2609 const TF_Operation* const* target_opers, int ntargets,
2610 const char** handle, TF_Status* status) {
2611 *handle = nullptr;
2612
2613 if (!ExtendSessionGraphHelper(session, status)) {
2614 return;
2615 }
2616
2617 std::vector<string> input_names(ninputs);
2618 for (int i = 0; i < ninputs; ++i) {
2619 input_names[i] = OutputName(inputs[i]);
2620 }
2621
2622 std::vector<string> output_names(noutputs);
2623 for (int i = 0; i < noutputs; ++i) {
2624 output_names[i] = OutputName(outputs[i]);
2625 }
2626
2627 std::vector<string> target_names(ntargets);
2628 for (int i = 0; i < ntargets; ++i) {
2629 target_names[i] = target_opers[i]->node.name();
2630 }
2631
2632 string new_handle;
2633 status->status = session->session->PRunSetup(input_names, output_names,
2634 target_names, &new_handle);
2635 if (status->status.ok()) {
2636 char* buf = new char[new_handle.size() + 1];
2637 memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
2638 *handle = buf;
2639 }
2640 }
2641
TF_DeletePRunHandle(const char * handle)2642 void TF_DeletePRunHandle(const char* handle) {
2643 delete[] handle;
2644 // TODO(suharshs): Free up any resources held by the partial run state.
2645 }
2646
TF_SessionPRun(TF_Session * session,const char * handle,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Status * status)2647 void TF_SessionPRun(TF_Session* session, const char* handle,
2648 const TF_Output* inputs, TF_Tensor* const* input_values,
2649 int ninputs, const TF_Output* outputs,
2650 TF_Tensor** output_values, int noutputs,
2651 const TF_Operation* const* target_opers, int ntargets,
2652 TF_Status* status) {
2653 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2654 // directly, instead of requiring us to serialize to a GraphDef and
2655 // call Session::Extend().
2656 if (!ExtendSessionGraphHelper(session, status)) {
2657 return;
2658 }
2659
2660 TF_Run_Setup(noutputs, output_values, status);
2661
2662 // Convert from TF_Output and TF_Tensor to a string and Tensor.
2663 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2664 if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2665 for (int i = 0; i < ninputs; ++i) {
2666 input_pairs[i].first = OutputName(inputs[i]);
2667 }
2668
2669 // Convert from TF_Output to string names.
2670 std::vector<string> output_names(noutputs);
2671 for (int i = 0; i < noutputs; ++i) {
2672 output_names[i] = OutputName(outputs[i]);
2673 }
2674
2675 // Convert from TF_Operation* to string names.
2676 std::vector<string> target_names(ntargets);
2677 for (int i = 0; i < ntargets; ++i) {
2678 target_names[i] = target_opers[i]->node.name();
2679 }
2680
2681 TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
2682 output_values, target_names, nullptr, status);
2683 }
2684
TF_NewApiDefMap(TF_Buffer * op_list_buffer,TF_Status * status)2685 TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
2686 tensorflow::OpList op_list;
2687 if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
2688 status->status = InvalidArgument("Unparseable OpList");
2689 return nullptr;
2690 }
2691 status->status = Status::OK();
2692 return new TF_ApiDefMap(op_list);
2693 }
2694
TF_DeleteApiDefMap(TF_ApiDefMap * apimap)2695 void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }
2696
TF_ApiDefMapPut(TF_ApiDefMap * api_def_map,const char * text,size_t text_len,TF_Status * status)2697 void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
2698 size_t text_len, TF_Status* status) {
2699 #ifdef __ANDROID__
2700 status->status = tensorflow::errors::Unimplemented(
2701 "ApiDefMap is not supported in Android.");
2702 #else
2703 mutex_lock l(api_def_map->lock);
2704 if (api_def_map->update_docs_called) {
2705 status->status = FailedPrecondition(
2706 "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
2707 "called.");
2708 return;
2709 }
2710 string api_def_text(text, text_len);
2711 status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
2712 #endif // __ANDROID__
2713 }
2714
TF_ApiDefMapGet(TF_ApiDefMap * api_def_map,const char * name,size_t name_len,TF_Status * status)2715 TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
2716 size_t name_len, TF_Status* status) {
2717 #ifdef __ANDROID__
2718 status->status = tensorflow::errors::Unimplemented(
2719 "ApiDefMap is not supported in Android.");
2720 return nullptr;
2721 #else
2722 mutex_lock l(api_def_map->lock);
2723 if (!api_def_map->update_docs_called) {
2724 api_def_map->api_def_map.UpdateDocs();
2725 api_def_map->update_docs_called = true;
2726 }
2727 string name_str(name, name_len);
2728 const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
2729
2730 TF_Buffer* ret = TF_NewBuffer();
2731 status->status = MessageToBuffer(*api_def, ret);
2732 return ret;
2733 #endif // __ANDROID__
2734 }
2735 } // end extern "C"
2736