1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/c_api.h"
17
18 #include <algorithm>
19 #include <limits>
20 #include <memory>
21 #include <vector>
22
23 #include "absl/strings/match.h"
24 // Required for IS_MOBILE_PLATFORM
25 #include "tensorflow/core/platform/platform.h" // NOLINT
26
27 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
28 #include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
29 #include "tensorflow/cc/framework/gradients.h"
30 #include "tensorflow/cc/framework/ops.h"
31 #include "tensorflow/cc/framework/scope_internal.h"
32 #include "tensorflow/cc/ops/while_loop.h"
33 #include "tensorflow/cc/saved_model/loader.h"
34 #include "tensorflow/core/distributed_runtime/server_lib.h"
35 #include "tensorflow/core/framework/logging.h"
36 #include "tensorflow/core/framework/op_gen_lib.h"
37 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
38 #include "tensorflow/c/c_api_internal.h"
39 #include "tensorflow/c/tf_status_internal.h"
40 #include "tensorflow/c/tf_tensor.h"
41 #include "tensorflow/core/common_runtime/device_mgr.h"
42 #include "tensorflow/core/common_runtime/eval_const_tensor.h"
43 #include "tensorflow/core/common_runtime/graph_constructor.h"
44 #include "tensorflow/core/common_runtime/shape_refiner.h"
45 #include "tensorflow/core/framework/allocation_description.pb.h"
46 #include "tensorflow/core/framework/kernel_def.pb.h"
47 #include "tensorflow/core/framework/log_memory.h"
48 #include "tensorflow/core/framework/node_def_util.h"
49 #include "tensorflow/core/framework/op_kernel.h"
50 #include "tensorflow/core/framework/partial_tensor_shape.h"
51 #include "tensorflow/core/framework/tensor.h"
52 #include "tensorflow/core/framework/tensor.pb.h" // NOLINT
53 #include "tensorflow/core/framework/tensor_shape.h"
54 #include "tensorflow/core/framework/tensor_shape.pb.h"
55 #include "tensorflow/core/framework/types.h"
56 #include "tensorflow/core/framework/versions.pb.h"
57 #include "tensorflow/core/graph/graph.h"
58 #include "tensorflow/core/graph/node_builder.h"
59 #include "tensorflow/core/graph/validate.h"
60 #include "tensorflow/core/lib/gtl/array_slice.h"
61 #include "tensorflow/core/platform/coding.h"
62 #include "tensorflow/core/platform/errors.h"
63 #include "tensorflow/core/platform/mem.h"
64 #include "tensorflow/core/platform/mutex.h"
65 #include "tensorflow/core/platform/protobuf.h"
66 #include "tensorflow/core/platform/status.h"
67 #include "tensorflow/core/platform/str_util.h"
68 #include "tensorflow/core/platform/strcat.h"
69 #include "tensorflow/core/platform/stringpiece.h"
70 #include "tensorflow/core/platform/thread_annotations.h"
71 #include "tensorflow/core/platform/types.h"
72 #include "tensorflow/core/public/session.h"
73 #include "tensorflow/core/public/version.h"
74
// The implementation below is at the top level instead of inside the
// tensorflow namespace because we are defining 'extern "C"' functions.
77 using tensorflow::AllocationDescription;
78 using tensorflow::DataType;
79 using tensorflow::ExtendSessionGraphHelper;
80 using tensorflow::Graph;
81 using tensorflow::GraphDef;
82 using tensorflow::mutex_lock;
83 using tensorflow::NameRangeMap;
84 using tensorflow::NameRangesForNode;
85 using tensorflow::NewSession;
86 using tensorflow::Node;
87 using tensorflow::NodeBuilder;
88 using tensorflow::NodeDef;
89 using tensorflow::OpDef;
90 using tensorflow::OpRegistry;
91 using tensorflow::OutputTensor;
92 using tensorflow::PartialTensorShape;
93 using tensorflow::RunMetadata;
94 using tensorflow::RunOptions;
95 using tensorflow::Session;
96 using tensorflow::Status;
97 using tensorflow::string;
98 using tensorflow::Tensor;
99 using tensorflow::TensorBuffer;
100 using tensorflow::TensorId;
101 using tensorflow::TensorShape;
102 using tensorflow::TensorShapeProto;
103 using tensorflow::VersionDef;
104 using tensorflow::errors::FailedPrecondition;
105 using tensorflow::errors::InvalidArgument;
106 using tensorflow::gtl::ArraySlice;
107 using tensorflow::strings::StrCat;
108
109 extern "C" {
110
111 // --------------------------------------------------------------------------
TF_Version()112 const char* TF_Version() { return TF_VERSION_STRING; }
113
114 // --------------------------------------------------------------------------
115
116 // --------------------------------------------------------------------------
TF_NewSessionOptions()117 TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
TF_DeleteSessionOptions(TF_SessionOptions * opt)118 void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
119
TF_SetTarget(TF_SessionOptions * options,const char * target)120 void TF_SetTarget(TF_SessionOptions* options, const char* target) {
121 options->options.target = target;
122 }
123
TF_SetConfig(TF_SessionOptions * options,const void * proto,size_t proto_len,TF_Status * status)124 void TF_SetConfig(TF_SessionOptions* options, const void* proto,
125 size_t proto_len, TF_Status* status) {
126 if (!options->options.config.ParseFromArray(proto, proto_len)) {
127 status->status = InvalidArgument("Unparseable ConfigProto");
128 }
129 }
130 // --------------------------------------------------------------------------
TF_NewBuffer()131 TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; }
132
TF_NewBufferFromString(const void * proto,size_t proto_len)133 TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
134 void* copy = tensorflow::port::Malloc(proto_len);
135 memcpy(copy, proto, proto_len);
136
137 TF_Buffer* buf = new TF_Buffer;
138 buf->data = copy;
139 buf->length = proto_len;
140 buf->data_deallocator = [](void* data, size_t length) {
141 tensorflow::port::Free(data);
142 };
143 return buf;
144 }
145
TF_DeleteBuffer(TF_Buffer * buffer)146 void TF_DeleteBuffer(TF_Buffer* buffer) {
147 if (buffer == nullptr) return;
148 if (buffer->data_deallocator != nullptr) {
149 (*buffer->data_deallocator)(const_cast<void*>(buffer->data),
150 buffer->length);
151 }
152 delete buffer;
153 }
154
TF_GetBuffer(TF_Buffer * buffer)155 TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }
156
TF_TensorFromProto(const TF_Buffer * from,TF_Tensor * to,TF_Status * status)157 void TF_TensorFromProto(const TF_Buffer* from, TF_Tensor* to,
158 TF_Status* status) {
159 TF_SetStatus(status, TF_OK, "");
160 tensorflow::TensorProto from_tensor_proto;
161 status->status = BufferToMessage(from, &from_tensor_proto);
162 if (!status->status.ok()) {
163 return;
164 }
165 status->status =
166 tensorflow::down_cast<tensorflow::TensorInterface*>(to->tensor)
167 ->FromProto(from_tensor_proto);
168 }
169 // --------------------------------------------------------------------------
170
TF_NewDeprecatedSession(const TF_SessionOptions * opt,TF_Status * status)171 TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
172 TF_Status* status) {
173 Session* session;
174 status->status = NewSession(opt->options, &session);
175 if (status->status.ok()) {
176 return new TF_DeprecatedSession({session});
177 } else {
178 DCHECK_EQ(nullptr, session);
179 return nullptr;
180 }
181 }
182
TF_CloseDeprecatedSession(TF_DeprecatedSession * s,TF_Status * status)183 void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
184 status->status = s->session->Close();
185 }
186
TF_DeleteDeprecatedSession(TF_DeprecatedSession * s,TF_Status * status)187 void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
188 status->status = Status::OK();
189 if (s == nullptr) return;
190 delete s->session;
191 delete s;
192 }
193
TF_ExtendGraph(TF_DeprecatedSession * s,const void * proto,size_t proto_len,TF_Status * status)194 void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
195 size_t proto_len, TF_Status* status) {
196 GraphDef g;
197 if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
198 status->status = InvalidArgument("Invalid GraphDef");
199 return;
200 }
201 status->status = s->session->Extend(g);
202 }
203
204 } // end extern "C"
205
206 // Reset helper for converting character arrays to string vectors.
TF_Reset_Helper(const TF_SessionOptions * opt,const char ** containers,int ncontainers,TF_Status * status)207 static void TF_Reset_Helper(const TF_SessionOptions* opt,
208 const char** containers, int ncontainers,
209 TF_Status* status) {
210 std::vector<string> container_names(ncontainers);
211 for (int i = 0; i < ncontainers; ++i) {
212 container_names[i] = containers[i];
213 }
214
215 status->status = Reset(opt->options, container_names);
216 }
217
218 extern "C" {
219
TF_Reset(const TF_SessionOptions * opt,const char ** containers,int ncontainers,TF_Status * status)220 void TF_Reset(const TF_SessionOptions* opt, const char** containers,
221 int ncontainers, TF_Status* status) {
222 TF_Reset_Helper(opt, containers, ncontainers, status);
223 }
224
225 } // end extern "C"
226
227 namespace tensorflow {
228
MessageToBuffer(const tensorflow::protobuf::MessageLite & in,TF_Buffer * out)229 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
230 TF_Buffer* out) {
231 if (out->data != nullptr) {
232 return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
233 }
234 const size_t proto_size = in.ByteSizeLong();
235 void* buf = port::Malloc(proto_size);
236 if (buf == nullptr) {
237 return tensorflow::errors::ResourceExhausted(
238 "Failed to allocate memory to serialize message of type '",
239 in.GetTypeName(), "' and size ", proto_size);
240 }
241 if (!in.SerializeWithCachedSizesToArray(static_cast<uint8*>(buf))) {
242 port::Free(buf);
243 return InvalidArgument("Unable to serialize ", in.GetTypeName(),
244 " protocol buffer, perhaps the serialized size (",
245 proto_size, " bytes) is too large?");
246 }
247 out->data = buf;
248 out->length = proto_size;
249 out->data_deallocator = [](void* data, size_t length) { port::Free(data); };
250 return Status::OK();
251 }
252
BufferToMessage(const TF_Buffer * in,tensorflow::protobuf::MessageLite * out)253 Status BufferToMessage(const TF_Buffer* in,
254 tensorflow::protobuf::MessageLite* out) {
255 if (in == nullptr || !out->ParseFromArray(in->data, in->length)) {
256 return errors::InvalidArgument("Unparseable ", out->GetTypeName(),
257 " proto");
258 }
259 return Status::OK();
260 }
261
RecordMutation(TF_Graph * graph,const TF_Operation & op,const char * mutation_type)262 void RecordMutation(TF_Graph* graph, const TF_Operation& op,
263 const char* mutation_type) {
264 // If any session has already run this node_id, mark this session as
265 // unrunnable.
266 for (auto it : graph->sessions) {
267 mutex_lock session_lock(it.first->mu);
268 if (it.first->last_num_graph_nodes > op.node.id()) {
269 it.second = strings::StrCat(
270 "Operation '", op.node.DebugString(), "' was changed by ",
271 mutation_type,
272 " after it was run by a session. This mutation will have no effect, "
273 "and will trigger an error in the future. Either don't modify "
274 "nodes after running them or create a new session.");
275 }
276 }
277 }
278
279 namespace {
280
281 // Helper method that creates a shape handle for a shape described by dims.
ShapeHandleFromDims(tensorflow::shape_inference::InferenceContext * ic,int num_dims,const int64_t * dims)282 tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
283 tensorflow::shape_inference::InferenceContext* ic, int num_dims,
284 const int64_t* dims) {
285 if (num_dims != -1) {
286 std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
287 dim_vec.reserve(num_dims);
288 for (int i = 0; i < num_dims; ++i) {
289 dim_vec.push_back(ic->MakeDim(dims[i]));
290 }
291 return ic->MakeShape(dim_vec);
292 } else {
293 return ic->UnknownShape();
294 }
295 }
296
297 } // namespace
298
TF_GraphSetOutputHandleShapesAndTypes(TF_Graph * graph,TF_Output output,int num_shapes_and_types,const int64_t ** shapes,const int * ranks,const TF_DataType * types,TF_Status * status)299 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
300 int num_shapes_and_types,
301 const int64_t** shapes,
302 const int* ranks,
303 const TF_DataType* types,
304 TF_Status* status) {
305 Node* node = &output.oper->node;
306
307 mutex_lock l(graph->mu);
308 tensorflow::shape_inference::InferenceContext* ic =
309 graph->refiner.GetContext(node);
310 if (ic == nullptr) {
311 status->status =
312 InvalidArgument("Node ", node->name(), " was not found in the graph");
313 return;
314 }
315
316 auto shape_and_type_vec =
317 std::vector<tensorflow::shape_inference::ShapeAndType>(
318 num_shapes_and_types);
319 for (int i = 0; i < num_shapes_and_types; ++i) {
320 tensorflow::shape_inference::ShapeHandle shape_handle =
321 ShapeHandleFromDims(ic, ranks[i], shapes[i]);
322 shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
323 shape_handle, static_cast<DataType>(types[i]));
324 }
325
326 ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
327 }
328
// Helpers for loading a TensorFlow plugin (a .so file).
//
// Declared here and defined elsewhere. As used by TF_LoadLibrary() in this
// file, `*result` receives the library handle and `*buf` / `*len` receive the
// library's op list buffer, which TF_DeleteLibraryHandle() later releases with
// port::Free.
// NOTE(review): the out-parameter semantics are inferred from that caller
// only; confirm against the definition.
Status LoadDynamicLibrary(const char* library_filename, void** result,
                          const void** buf, size_t* len);
332
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
// call Session::Extend().
//
// Brings session->session up to date with session->graph. Nodes with ids in
// [session->last_num_graph_nodes, graph.num_node_ids()) are serialized,
// together with the graph's versions and function library, and passed to
// Session::Extend(). Any mutation warning stored for this session in
// graph->sessions is logged and cleared first.
//
// Returns true on success, or when there is nothing to do (including when
// session->graph is null). Returns false with `status` set if the graph
// contains a cycle or Extend() fails.
//
// Locking: graph->mu is taken before session->mu. graph->mu is released
// manually on every path before returning, and before Extend() is called.
// session->mu is held for the rest of the block, including across Extend().
bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
  if (session->graph != nullptr) {
    // Take the graph lock before the session lock to avoid deadlock. This is
    // safe since session->graph does not change.
    session->graph->mu.lock();
    mutex_lock session_lock(session->mu);
    const Graph& graph = session->graph->graph;

    const string& mutation_warning = session->graph->sessions[session];
    if (!mutation_warning.empty()) {
      // TODO(b/74949947): turn this back into an error status
      LOG(WARNING) << mutation_warning;
      session->graph->sessions[session].clear();
    }

    const auto num_nodes = graph.num_node_ids();
    if (session->last_num_graph_nodes < num_nodes) {
      // TODO(nolivia): check this on a subset of the graph instead of all of
      // it.
      status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
      if (!status->status.ok()) {
        session->graph->mu.unlock();
        return false;
      }

      GraphDef graph_def;
      *graph_def.mutable_versions() = graph.versions();
      // Fill graph_def with nodes with ids in the range
      // [session->last_num_graph_nodes, num_nodes), that is the nodes
      // added since the last TF_SessionRun() call.
      for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
        Node* const node = graph.FindNodeId(id);
        if (node != nullptr && node->IsOp()) {
          NodeDef* const node_def = graph_def.add_node();
          *node_def = node->def();
        }
      }
      *graph_def.mutable_library() = graph.flib_def().ToProto();
      // graph_def is a standalone copy, so the graph lock is released before
      // the (possibly slow) Extend() call; session->mu remains held.
      session->graph->mu.unlock();
      status->status = session->session->Extend(std::move(graph_def));
      if (!status->status.ok()) {
        // last_num_graph_nodes is left unchanged, so a later call starts
        // from the same node id again.
        return false;
      }
      // Note: session->session is not modified if Extend() fails, so
      // we only set last_num_graph_nodes if it succeeds.
      session->last_num_graph_nodes = num_nodes;
    } else {
      session->graph->mu.unlock();
    }
  }
  return true;
}
389
390 } // namespace tensorflow
391
TF_Run_Setup(int noutputs,TF_Tensor ** c_outputs,TF_Status * status)392 static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
393 TF_Status* status) {
394 status->status = Status::OK();
395 for (int i = 0; i < noutputs; ++i) {
396 c_outputs[i] = nullptr;
397 }
398 }
399
TF_Run_Inputs(TF_Tensor * const * c_inputs,std::vector<std::pair<string,Tensor>> * input_pairs,TF_Status * status)400 static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
401 std::vector<std::pair<string, Tensor>>* input_pairs,
402 TF_Status* status) {
403 const int ninputs = input_pairs->size();
404 for (int i = 0; i < ninputs; ++i) {
405 status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
406 if (!status->status.ok()) return false;
407 }
408 return true;
409 }
410
411 // Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
412 // result in a zero-sized tensor.
EmptyTensor(TF_DataType dtype,const tensorflow::TensorShape & shape)413 static TF_Tensor* EmptyTensor(TF_DataType dtype,
414 const tensorflow::TensorShape& shape) {
415 static char empty;
416 int64_t nelems = 1;
417 std::vector<tensorflow::int64> dims;
418 for (int i = 0; i < shape.dims(); ++i) {
419 dims.push_back(shape.dim_size(i));
420 nelems *= shape.dim_size(i);
421 }
422 CHECK_EQ(nelems, 0);
423 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
424 "64-bit int types should match in size");
425 return TF_NewTensor(
426 dtype, reinterpret_cast<const int64_t*>(dims.data()), shape.dims(),
427 reinterpret_cast<void*>(&empty), 0, [](void*, size_t, void*) {}, nullptr);
428 }
429
TF_Run_Helper(Session * session,const char * handle,const TF_Buffer * run_options,const std::vector<std::pair<string,Tensor>> & input_pairs,const std::vector<string> & output_tensor_names,TF_Tensor ** c_outputs,const std::vector<string> & target_oper_names,TF_Buffer * run_metadata,TF_Status * status)430 static void TF_Run_Helper(
431 Session* session, const char* handle, const TF_Buffer* run_options,
432 // Input tensors
433 const std::vector<std::pair<string, Tensor>>& input_pairs,
434 // Output tensors
435 const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
436 // Target nodes
437 const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
438 TF_Status* status) {
439 const int noutputs = output_tensor_names.size();
440 std::vector<Tensor> outputs(noutputs);
441 Status result;
442
443 if (handle == nullptr) {
444 RunOptions run_options_proto;
445 if (run_options != nullptr && !run_options_proto.ParseFromArray(
446 run_options->data, run_options->length)) {
447 status->status = InvalidArgument("Unparseable RunOptions proto");
448 return;
449 }
450 if (run_metadata != nullptr && run_metadata->data != nullptr) {
451 status->status =
452 InvalidArgument("Passing non-empty run_metadata is invalid.");
453 return;
454 }
455
456 RunMetadata run_metadata_proto;
457 result = session->Run(run_options_proto, input_pairs, output_tensor_names,
458 target_oper_names, &outputs, &run_metadata_proto);
459
460 // Serialize back to upstream client, who now owns the new buffer
461 if (run_metadata != nullptr) {
462 status->status = MessageToBuffer(run_metadata_proto, run_metadata);
463 if (!status->status.ok()) return;
464 }
465 } else {
466 // NOTE(zongheng): PRun does not support RunOptions yet.
467 result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
468 }
469 if (!result.ok()) {
470 status->status = result;
471 return;
472 }
473
474 // Store results in c_outputs[]
475 for (int i = 0; i < noutputs; ++i) {
476 const Tensor& src = outputs[i];
477 if (!src.IsInitialized() || src.NumElements() == 0) {
478 c_outputs[i] =
479 EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
480 continue;
481 }
482 c_outputs[i] = TF_TensorFromTensor(src, &status->status);
483 if (!status->status.ok()) return;
484 }
485 }
486
487 extern "C" {
488
TF_Run(TF_DeprecatedSession * s,const TF_Buffer * run_options,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Buffer * run_metadata,TF_Status * status)489 void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
490 // Input tensors
491 const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
492 // Output tensors
493 const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
494 // Target nodes
495 const char** c_target_oper_names, int ntargets,
496 TF_Buffer* run_metadata, TF_Status* status) {
497 TF_Run_Setup(noutputs, c_outputs, status);
498 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
499 if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
500 for (int i = 0; i < ninputs; ++i) {
501 input_pairs[i].first = c_input_names[i];
502 }
503 std::vector<string> output_names(noutputs);
504 for (int i = 0; i < noutputs; ++i) {
505 output_names[i] = c_output_names[i];
506 }
507 std::vector<string> target_oper_names(ntargets);
508 for (int i = 0; i < ntargets; ++i) {
509 target_oper_names[i] = c_target_oper_names[i];
510 }
511 TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
512 c_outputs, target_oper_names, run_metadata, status);
513 }
514
TF_PRunSetup(TF_DeprecatedSession * s,const char ** c_input_names,int ninputs,const char ** c_output_names,int noutputs,const char ** c_target_oper_names,int ntargets,const char ** handle,TF_Status * status)515 void TF_PRunSetup(TF_DeprecatedSession* s,
516 // Input names
517 const char** c_input_names, int ninputs,
518 // Output names
519 const char** c_output_names, int noutputs,
520 // Target nodes
521 const char** c_target_oper_names, int ntargets,
522 const char** handle, TF_Status* status) {
523 *handle = nullptr;
524
525 std::vector<string> input_names(ninputs);
526 std::vector<string> output_names(noutputs);
527 std::vector<string> target_oper_names(ntargets);
528 for (int i = 0; i < ninputs; ++i) {
529 input_names[i] = c_input_names[i];
530 }
531 for (int i = 0; i < noutputs; ++i) {
532 output_names[i] = c_output_names[i];
533 }
534 for (int i = 0; i < ntargets; ++i) {
535 target_oper_names[i] = c_target_oper_names[i];
536 }
537 string new_handle;
538 status->status = s->session->PRunSetup(input_names, output_names,
539 target_oper_names, &new_handle);
540 if (status->status.ok()) {
541 char* buf = new char[new_handle.size() + 1];
542 memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
543 *handle = buf;
544 }
545 }
546
TF_PRun(TF_DeprecatedSession * s,const char * handle,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Status * status)547 void TF_PRun(TF_DeprecatedSession* s, const char* handle,
548 // Input tensors
549 const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
550 // Output tensors
551 const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
552 // Target nodes
553 const char** c_target_oper_names, int ntargets,
554 TF_Status* status) {
555 TF_Run_Setup(noutputs, c_outputs, status);
556 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
557 if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
558 for (int i = 0; i < ninputs; ++i) {
559 input_pairs[i].first = c_input_names[i];
560 }
561
562 std::vector<string> output_names(noutputs);
563 for (int i = 0; i < noutputs; ++i) {
564 output_names[i] = c_output_names[i];
565 }
566 std::vector<string> target_oper_names(ntargets);
567 for (int i = 0; i < ntargets; ++i) {
568 target_oper_names[i] = c_target_oper_names[i];
569 }
570 TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
571 c_outputs, target_oper_names, nullptr, status);
572 }
573
TF_LoadLibrary(const char * library_filename,TF_Status * status)574 TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
575 TF_Library* lib_handle = new TF_Library;
576 status->status = tensorflow::LoadDynamicLibrary(
577 library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
578 &lib_handle->op_list.length);
579 if (!status->status.ok()) {
580 delete lib_handle;
581 return nullptr;
582 }
583 return lib_handle;
584 }
585
TF_GetOpList(TF_Library * lib_handle)586 TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
587
TF_DeleteLibraryHandle(TF_Library * lib_handle)588 void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
589 if (lib_handle == nullptr) return;
590 tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
591 delete lib_handle;
592 }
593
TF_GetAllOpList()594 TF_Buffer* TF_GetAllOpList() {
595 std::vector<tensorflow::OpDef> op_defs;
596 tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
597 tensorflow::OpList op_list;
598 for (const auto& op : op_defs) {
599 *(op_list.add_op()) = op;
600 }
601 TF_Buffer* ret = TF_NewBuffer();
602 TF_CHECK_OK(MessageToBuffer(op_list, ret));
603 return ret;
604 }
605
606 // --------------------------------------------------------------------------
607 // ListDevices & SessionListDevices API
608
TF_DeleteDeviceList(TF_DeviceList * list)609 void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; }
610
TF_SessionListDevices(TF_Session * session,TF_Status * status)611 TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
612 TF_DeviceList* response = new TF_DeviceList;
613 if (session && session->session)
614 status->status = session->session->ListDevices(&response->response);
615 return response;
616 }
617
TF_DeprecatedSessionListDevices(TF_DeprecatedSession * session,TF_Status * status)618 TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
619 TF_Status* status) {
620 TF_DeviceList* response = new TF_DeviceList;
621 if (session && session->session)
622 status->status = session->session->ListDevices(&response->response);
623 return response;
624 }
625
TF_DeviceListCount(const TF_DeviceList * list)626 int TF_DeviceListCount(const TF_DeviceList* list) {
627 return list->response.size();
628 }
629
630 #define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
631 return_type method_name(const TF_DeviceList* list, const int index, \
632 TF_Status* status) { \
633 if (list == nullptr) { \
634 status->status = InvalidArgument("list is null!"); \
635 return err_val; \
636 } \
637 if (index < 0 || index >= list->response.size()) { \
638 status->status = InvalidArgument("index out of bounds"); \
639 return err_val; \
640 } \
641 status->status = Status::OK(); \
642 return list->response[index].accessor; \
643 }
644
645 TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
646 TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
647 nullptr);
648 TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
649 TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);
650
651 #undef TF_DEVICELIST_METHOD
652
653 } // end extern "C"
654
655 // --------------------------------------------------------------------------
656 // New Graph and Session API
657
658 // Helper functions -----------------------------------------------------------
659
660 namespace {
661
ToOperation(Node * node)662 TF_Operation* ToOperation(Node* node) {
663 return static_cast<TF_Operation*>(static_cast<void*>(node));
664 }
665
OutputName(const TF_Output & output)666 string OutputName(const TF_Output& output) {
667 return StrCat(output.oper->node.name(), ":", output.index);
668 }
669
GetAttrValue(TF_Operation * oper,const char * attr_name,TF_Status * status)670 const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
671 const char* attr_name,
672 TF_Status* status) {
673 const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
674 if (attr == nullptr) {
675 status->status = InvalidArgument("Operation '", oper->node.name(),
676 "' has no attr named '", attr_name, "'.");
677 }
678 return attr;
679 }
680
ToTensorId(const TF_Output & output)681 TensorId ToTensorId(const TF_Output& output) {
682 return TensorId(output.oper->node.name(), output.index);
683 }
684
685 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
OutputsFromTFOutputs(TF_Output * tf_outputs,int n)686 std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
687 int n) {
688 std::vector<tensorflow::Output> outputs(n);
689 for (int i = 0; i < n; ++i) {
690 outputs[i] =
691 tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
692 }
693 return outputs;
694 }
695
TFOutputsFromOutputs(const std::vector<tensorflow::Output> & outputs,TF_Output * tf_outputs)696 void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
697 TF_Output* tf_outputs) {
698 for (int i = 0; i < outputs.size(); i++) {
699 tf_outputs[i].oper = ToOperation(outputs[i].node());
700 tf_outputs[i].index = outputs[i].index();
701 }
702 }
703 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
704
705 } // namespace
706
707 // Shape functions -----------------------------------------------------------
708
TF_GraphSetTensorShape(TF_Graph * graph,TF_Output output,const int64_t * dims,const int num_dims,TF_Status * status)709 void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
710 const int64_t* dims, const int num_dims,
711 TF_Status* status) {
712 Node* node = &output.oper->node;
713
714 mutex_lock l(graph->mu);
715 tensorflow::shape_inference::InferenceContext* ic =
716 graph->refiner.GetContext(node);
717 if (ic == nullptr) {
718 status->status =
719 InvalidArgument("Node ", node->name(), " was not found in the graph");
720 return;
721 }
722 tensorflow::shape_inference::ShapeHandle new_shape =
723 tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
724 status->status = graph->refiner.SetShape(node, output.index, new_shape);
725 }
726
TF_GraphGetTensorNumDims(TF_Graph * graph,TF_Output output,TF_Status * status)727 int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
728 TF_Status* status) {
729 Node* node = &output.oper->node;
730
731 mutex_lock l(graph->mu);
732 tensorflow::shape_inference::InferenceContext* ic =
733 graph->refiner.GetContext(node);
734 if (ic == nullptr) {
735 status->status =
736 InvalidArgument("Node ", node->name(), " was not found in the graph");
737 return -1;
738 }
739
740 tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
741
742 // Unknown rank means the number of dimensions is -1.
743 if (!ic->RankKnown(shape)) {
744 return -1;
745 }
746
747 return ic->Rank(shape);
748 }
749
TF_GraphGetTensorShape(TF_Graph * graph,TF_Output output,int64_t * dims,int num_dims,TF_Status * status)750 void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
751 int num_dims, TF_Status* status) {
752 Node* node = &output.oper->node;
753
754 mutex_lock l(graph->mu);
755 tensorflow::shape_inference::InferenceContext* ic =
756 graph->refiner.GetContext(node);
757 if (ic == nullptr) {
758 status->status =
759 InvalidArgument("Node ", node->name(), " was not found in the graph");
760 return;
761 }
762
763 tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
764
765 int rank = -1;
766 if (ic->RankKnown(shape)) {
767 rank = ic->Rank(shape);
768 }
769
770 if (num_dims != rank) {
771 status->status = InvalidArgument("Expected rank is ", num_dims,
772 " but actual rank is ", rank);
773 return;
774 }
775
776 if (num_dims == 0) {
777 // Output shape is a scalar.
778 return;
779 }
780
781 // Rank is greater than 0, so fill in the values, if known, and
782 // -1 for unknown values.
783 for (int i = 0; i < num_dims; ++i) {
784 auto dim = ic->Dim(shape, i);
785 int64_t value = -1;
786 if (ic->ValueKnown(dim)) {
787 value = ic->Value(dim);
788 }
789 dims[i] = value;
790 }
791 }
792
793 // TF_OperationDescription functions ------------------------------------------
794
795 extern "C" {
796
TF_NewOperationLocked(TF_Graph * graph,const char * op_type,const char * oper_name)797 TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
798 const char* op_type,
799 const char* oper_name)
800 TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
801 return new TF_OperationDescription(graph, op_type, oper_name);
802 }
803
TF_NewOperation(TF_Graph * graph,const char * op_type,const char * oper_name)804 TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
805 const char* oper_name) {
806 mutex_lock l(graph->mu);
807 return TF_NewOperationLocked(graph, op_type, oper_name);
808 }
809
TF_SetDevice(TF_OperationDescription * desc,const char * device)810 void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
811 desc->node_builder.Device(device);
812 }
813
TF_AddInput(TF_OperationDescription * desc,TF_Output input)814 void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
815 desc->node_builder.Input(&input.oper->node, input.index);
816 }
817
TF_AddInputList(TF_OperationDescription * desc,const TF_Output * inputs,int num_inputs)818 void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
819 int num_inputs) {
820 std::vector<NodeBuilder::NodeOut> input_list;
821 input_list.reserve(num_inputs);
822 for (int i = 0; i < num_inputs; ++i) {
823 input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
824 }
825 desc->node_builder.Input(input_list);
826 }
827
TF_AddControlInput(TF_OperationDescription * desc,TF_Operation * input)828 void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
829 desc->node_builder.ControlInput(&input->node);
830 }
831
TF_ColocateWith(TF_OperationDescription * desc,TF_Operation * op)832 void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
833 desc->colocation_constraints.emplace(
834 StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
835 }
836
TF_SetAttrString(TF_OperationDescription * desc,const char * attr_name,const void * value,size_t length)837 void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
838 const void* value, size_t length) {
839 tensorflow::StringPiece s(static_cast<const char*>(value), length);
840 desc->node_builder.Attr(attr_name, s);
841 }
842
TF_SetAttrStringList(TF_OperationDescription * desc,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)843 void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
844 const void* const* values, const size_t* lengths,
845 int num_values) {
846 if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
847 desc->colocation_constraints.clear();
848 for (int i = 0; i < num_values; ++i) {
849 desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
850 lengths[i]);
851 }
852 } else {
853 std::vector<tensorflow::StringPiece> v;
854 v.reserve(num_values);
855 for (int i = 0; i < num_values; ++i) {
856 v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
857 }
858 desc->node_builder.Attr(attr_name, v);
859 }
860 }
861
TF_SetAttrInt(TF_OperationDescription * desc,const char * attr_name,int64_t value)862 void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
863 int64_t value) {
864 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
865 "64-bit int types should match in size");
866 desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
867 }
868
TF_SetAttrIntList(TF_OperationDescription * desc,const char * attr_name,const int64_t * values,int num_values)869 void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
870 const int64_t* values, int num_values) {
871 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
872 "64-bit int types should match in size");
873 desc->node_builder.Attr(
874 attr_name,
875 ArraySlice<const tensorflow::int64>(
876 reinterpret_cast<const tensorflow::int64*>(values), num_values));
877 }
878
TF_SetAttrFloat(TF_OperationDescription * desc,const char * attr_name,float value)879 void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
880 float value) {
881 desc->node_builder.Attr(attr_name, value);
882 }
883
TF_SetAttrFloatList(TF_OperationDescription * desc,const char * attr_name,const float * values,int num_values)884 void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
885 const float* values, int num_values) {
886 desc->node_builder.Attr(attr_name,
887 ArraySlice<const float>(values, num_values));
888 }
889
TF_SetAttrBool(TF_OperationDescription * desc,const char * attr_name,unsigned char value)890 void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
891 unsigned char value) {
892 desc->node_builder.Attr(attr_name, static_cast<bool>(value));
893 }
894
TF_SetAttrBoolList(TF_OperationDescription * desc,const char * attr_name,const unsigned char * values,int num_values)895 void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
896 const unsigned char* values, int num_values) {
897 std::unique_ptr<bool[]> b(new bool[num_values]);
898 for (int i = 0; i < num_values; ++i) {
899 b[i] = values[i];
900 }
901 desc->node_builder.Attr(attr_name,
902 ArraySlice<const bool>(b.get(), num_values));
903 }
904
TF_SetAttrType(TF_OperationDescription * desc,const char * attr_name,TF_DataType value)905 void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
906 TF_DataType value) {
907 desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
908 }
909
TF_SetAttrTypeList(TF_OperationDescription * desc,const char * attr_name,const TF_DataType * values,int num_values)910 void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
911 const TF_DataType* values, int num_values) {
912 desc->node_builder.Attr(
913 attr_name, ArraySlice<const DataType>(
914 reinterpret_cast<const DataType*>(values), num_values));
915 }
916
TF_SetAttrPlaceholder(TF_OperationDescription * desc,const char * attr_name,const char * placeholder)917 void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name,
918 const char* placeholder) {
919 tensorflow::AttrValue attr_value;
920 attr_value.set_placeholder(placeholder);
921 desc->node_builder.Attr(attr_name, attr_value);
922 }
923
TF_SetAttrFuncName(TF_OperationDescription * desc,const char * attr_name,const char * value,size_t length)924 void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
925 const char* value, size_t length) {
926 tensorflow::NameAttrList func_name;
927 func_name.set_name(string(value, value + length));
928 desc->node_builder.Attr(attr_name, func_name);
929 }
930
TF_SetAttrShape(TF_OperationDescription * desc,const char * attr_name,const int64_t * dims,int num_dims)931 void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
932 const int64_t* dims, int num_dims) {
933 PartialTensorShape shape;
934 if (num_dims >= 0) {
935 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
936 "64-bit int types should match in size");
937 shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
938 reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
939 }
940 desc->node_builder.Attr(attr_name, shape);
941 }
942
TF_SetAttrShapeList(TF_OperationDescription * desc,const char * attr_name,const int64_t * const * dims,const int * num_dims,int num_shapes)943 void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
944 const int64_t* const* dims, const int* num_dims,
945 int num_shapes) {
946 std::vector<PartialTensorShape> shapes;
947 shapes.reserve(num_shapes);
948 for (int i = 0; i < num_shapes; ++i) {
949 if (num_dims[i] < 0) {
950 shapes.emplace_back();
951 } else {
952 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
953 "64-bit int types should match in size");
954 shapes.emplace_back(ArraySlice<tensorflow::int64>(
955 reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
956 }
957 }
958 desc->node_builder.Attr(attr_name, shapes);
959 }
960
TF_SetAttrTensorShapeProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)961 void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
962 const char* attr_name, const void* proto,
963 size_t proto_len, TF_Status* status) {
964 // shape.ParseFromArray takes an int as length, this function takes size_t,
965 // make sure there is no information loss.
966 if (proto_len > std::numeric_limits<int>::max()) {
967 status->status = InvalidArgument(
968 "proto_len (", proto_len,
969 " bytes) is too large to be parsed by the protocol buffer library");
970 return;
971 }
972 TensorShapeProto shape;
973 if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
974 desc->node_builder.Attr(attr_name, shape);
975 status->status = Status::OK();
976 } else {
977 status->status = InvalidArgument("Unparseable TensorShapeProto");
978 }
979 }
980
TF_SetAttrTensorShapeProtoList(TF_OperationDescription * desc,const char * attr_name,const void * const * protos,const size_t * proto_lens,int num_shapes,TF_Status * status)981 void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
982 const char* attr_name,
983 const void* const* protos,
984 const size_t* proto_lens, int num_shapes,
985 TF_Status* status) {
986 std::vector<TensorShapeProto> shapes;
987 shapes.resize(num_shapes);
988 for (int i = 0; i < num_shapes; ++i) {
989 if (proto_lens[i] > std::numeric_limits<int>::max()) {
990 status->status = InvalidArgument(
991 "length of element ", i, " in the list (", proto_lens[i],
992 " bytes) is too large to be parsed by the protocol buffer library");
993 return;
994 }
995 if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) {
996 status->status =
997 InvalidArgument("Unparseable TensorShapeProto at index ", i);
998 return;
999 }
1000 }
1001 desc->node_builder.Attr(attr_name, shapes);
1002 status->status = Status::OK();
1003 }
1004
TF_SetAttrTensor(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * value,TF_Status * status)1005 void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
1006 TF_Tensor* value, TF_Status* status) {
1007 Tensor t;
1008 status->status = TF_TensorToTensor(value, &t);
1009 if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1010 }
1011
TF_SetAttrTensorList(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * const * values,int num_values,TF_Status * status)1012 void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
1013 TF_Tensor* const* values, int num_values,
1014 TF_Status* status) {
1015 status->status = Status::OK();
1016 std::vector<Tensor> t;
1017 t.reserve(num_values);
1018
1019 for (int i = 0; i < num_values && status->status.ok(); ++i) {
1020 Tensor v;
1021 status->status = TF_TensorToTensor(values[i], &v);
1022 t.emplace_back(v);
1023 }
1024
1025 if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1026 }
1027
TF_SetAttrValueProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)1028 void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
1029 const void* proto, size_t proto_len,
1030 TF_Status* status) {
1031 tensorflow::AttrValue attr_value;
1032 if (!attr_value.ParseFromArray(proto, proto_len)) {
1033 status->status = InvalidArgument("Unparseable AttrValue proto");
1034 return;
1035 }
1036
1037 if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1038 if (attr_value.value_case() != tensorflow::AttrValue::kList &&
1039 attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
1040 status->status =
1041 InvalidArgument("Expected \"list\" field for \"",
1042 tensorflow::kColocationAttrName, "\" attribute");
1043 return;
1044 }
1045 desc->colocation_constraints.clear();
1046 for (const string& location : attr_value.list().s()) {
1047 desc->colocation_constraints.insert(location);
1048 }
1049 } else {
1050 desc->node_builder.Attr(attr_name, std::move(attr_value));
1051 }
1052
1053 status->status = Status::OK();
1054 }
1055
TF_FinishOperationLocked(TF_OperationDescription * desc,TF_Status * status)1056 TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
1057 TF_Status* status)
1058 TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
1059 Node* ret = nullptr;
1060
1061 if (desc->graph->name_map.count(desc->node_builder.node_name())) {
1062 status->status = InvalidArgument("Duplicate node name in graph: '",
1063 desc->node_builder.node_name(), "'");
1064 } else {
1065 if (!desc->colocation_constraints.empty()) {
1066 desc->node_builder.Attr(
1067 tensorflow::kColocationAttrName,
1068 std::vector<string>(desc->colocation_constraints.begin(),
1069 desc->colocation_constraints.end()));
1070 }
1071 status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret,
1072 /*consume=*/true);
1073
1074 if (status->status.ok()) {
1075 // Run shape inference function for newly added node.
1076 status->status = desc->graph->refiner.AddNode(ret);
1077 }
1078 if (status->status.ok()) {
1079 // Add the node to the name-to-node mapping.
1080 desc->graph->name_map[ret->name()] = ret;
1081 } else if (ret != nullptr) {
1082 desc->graph->graph.RemoveNode(ret);
1083 ret = nullptr;
1084 }
1085 }
1086
1087 delete desc;
1088
1089 return ToOperation(ret);
1090 }
1091
TF_FinishOperation(TF_OperationDescription * desc,TF_Status * status)1092 TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
1093 TF_Status* status) {
1094 mutex_lock l(desc->graph->mu);
1095 return TF_FinishOperationLocked(desc, status);
1096 }
1097
1098 // TF_Operation functions
1099 // ----------------------------------------------------------
1100
TF_OperationName(TF_Operation * oper)1101 const char* TF_OperationName(TF_Operation* oper) {
1102 return oper->node.name().c_str();
1103 }
1104
TF_OperationOpType(TF_Operation * oper)1105 const char* TF_OperationOpType(TF_Operation* oper) {
1106 return oper->node.type_string().c_str();
1107 }
1108
TF_OperationDevice(TF_Operation * oper)1109 const char* TF_OperationDevice(TF_Operation* oper) {
1110 return oper->node.requested_device().c_str();
1111 }
1112
TF_OperationNumOutputs(TF_Operation * oper)1113 int TF_OperationNumOutputs(TF_Operation* oper) {
1114 return oper->node.num_outputs();
1115 }
1116
TF_OperationOutputType(TF_Output oper_out)1117 TF_DataType TF_OperationOutputType(TF_Output oper_out) {
1118 return static_cast<TF_DataType>(
1119 oper_out.oper->node.output_type(oper_out.index));
1120 }
1121
TF_OperationOutputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1122 int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
1123 TF_Status* status) {
1124 NameRangeMap name_ranges;
1125 status->status =
1126 NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
1127 if (!status->status.ok()) return -1;
1128 auto iter = name_ranges.find(arg_name);
1129 if (iter == name_ranges.end()) {
1130 status->status = InvalidArgument("Output arg '", arg_name, "' not found");
1131 return -1;
1132 }
1133 return iter->second.second - iter->second.first;
1134 }
1135
TF_OperationNumInputs(TF_Operation * oper)1136 int TF_OperationNumInputs(TF_Operation* oper) {
1137 return oper->node.num_inputs();
1138 }
1139
TF_OperationInputType(TF_Input oper_in)1140 TF_DataType TF_OperationInputType(TF_Input oper_in) {
1141 return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
1142 }
1143
TF_OperationInputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1144 int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
1145 TF_Status* status) {
1146 NameRangeMap name_ranges;
1147 status->status =
1148 NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
1149 if (!status->status.ok()) return -1;
1150 auto iter = name_ranges.find(arg_name);
1151 if (iter == name_ranges.end()) {
1152 status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1153 return -1;
1154 }
1155 return iter->second.second - iter->second.first;
1156 }
1157
TF_OperationInput(TF_Input oper_in)1158 TF_Output TF_OperationInput(TF_Input oper_in) {
1159 const tensorflow::Edge* edge;
1160 Status s = oper_in.oper->node.input_edge(oper_in.index, &edge);
1161 if (!s.ok()) {
1162 return {nullptr, -1};
1163 }
1164
1165 return {ToOperation(edge->src()), edge->src_output()};
1166 }
1167
TF_OperationAllInputs(TF_Operation * oper,TF_Output * inputs,int max_inputs)1168 void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs,
1169 int max_inputs) {
1170 for (auto* edge : oper->node.in_edges()) {
1171 if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) {
1172 inputs[edge->dst_input()] = {ToOperation(edge->src()),
1173 edge->src_output()};
1174 }
1175 }
1176 }
1177
TF_OperationOutputNumConsumers(TF_Output oper_out)1178 int TF_OperationOutputNumConsumers(TF_Output oper_out) {
1179 int count = 0;
1180 for (const auto* edge : oper_out.oper->node.out_edges()) {
1181 if (edge->src_output() == oper_out.index) {
1182 ++count;
1183 }
1184 }
1185 return count;
1186 }
1187
TF_OperationOutputConsumers(TF_Output oper_out,TF_Input * consumers,int max_consumers)1188 int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
1189 int max_consumers) {
1190 int count = 0;
1191 for (const auto* edge : oper_out.oper->node.out_edges()) {
1192 if (edge->src_output() == oper_out.index) {
1193 if (count < max_consumers) {
1194 consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
1195 }
1196 ++count;
1197 }
1198 }
1199 return count;
1200 }
1201
TF_OperationNumControlInputs(TF_Operation * oper)1202 int TF_OperationNumControlInputs(TF_Operation* oper) {
1203 int count = 0;
1204 for (const auto* edge : oper->node.in_edges()) {
1205 if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1206 ++count;
1207 }
1208 }
1209 return count;
1210 }
1211
TF_OperationGetControlInputs(TF_Operation * oper,TF_Operation ** control_inputs,int max_control_inputs)1212 int TF_OperationGetControlInputs(TF_Operation* oper,
1213 TF_Operation** control_inputs,
1214 int max_control_inputs) {
1215 int count = 0;
1216 for (const auto* edge : oper->node.in_edges()) {
1217 if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1218 if (count < max_control_inputs) {
1219 control_inputs[count] = ToOperation(edge->src());
1220 }
1221 ++count;
1222 }
1223 }
1224 return count;
1225 }
1226
TF_OperationNumControlOutputs(TF_Operation * oper)1227 int TF_OperationNumControlOutputs(TF_Operation* oper) {
1228 int count = 0;
1229 for (const auto* edge : oper->node.out_edges()) {
1230 if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1231 ++count;
1232 }
1233 }
1234 return count;
1235 }
1236
TF_OperationGetControlOutputs(TF_Operation * oper,TF_Operation ** control_outputs,int max_control_outputs)1237 int TF_OperationGetControlOutputs(TF_Operation* oper,
1238 TF_Operation** control_outputs,
1239 int max_control_outputs) {
1240 int count = 0;
1241 for (const auto* edge : oper->node.out_edges()) {
1242 if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1243 if (count < max_control_outputs) {
1244 control_outputs[count] = ToOperation(edge->dst());
1245 }
1246 ++count;
1247 }
1248 }
1249 return count;
1250 }
1251
TF_OperationGetAttrMetadata(TF_Operation * oper,const char * attr_name,TF_Status * status)1252 TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1253 const char* attr_name,
1254 TF_Status* status) {
1255 TF_AttrMetadata metadata;
1256 const auto* attr = GetAttrValue(oper, attr_name, status);
1257 if (!status->status.ok()) return metadata;
1258 switch (attr->value_case()) {
1259 #define SINGLE_CASE(kK, attr_type, size_expr) \
1260 case tensorflow::AttrValue::kK: \
1261 metadata.is_list = 0; \
1262 metadata.list_size = -1; \
1263 metadata.type = attr_type; \
1264 metadata.total_size = size_expr; \
1265 break;
1266
1267 SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
1268 SINGLE_CASE(kI, TF_ATTR_INT, -1);
1269 SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
1270 SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
1271 SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
1272 SINGLE_CASE(kShape, TF_ATTR_SHAPE,
1273 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
1274 SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
1275 #undef SINGLE_CASE
1276
1277 case tensorflow::AttrValue::kList:
1278 metadata.is_list = 1;
1279 metadata.list_size = 0;
1280 metadata.total_size = -1;
1281 #define LIST_CASE(field, attr_type, ...) \
1282 if (attr->list().field##_size() > 0) { \
1283 metadata.type = attr_type; \
1284 metadata.list_size = attr->list().field##_size(); \
1285 __VA_ARGS__; \
1286 break; \
1287 }
1288
1289 LIST_CASE(
1290 s, TF_ATTR_STRING, metadata.total_size = 0;
1291 for (int i = 0; i < attr->list().s_size();
1292 ++i) { metadata.total_size += attr->list().s(i).size(); });
1293 LIST_CASE(i, TF_ATTR_INT);
1294 LIST_CASE(f, TF_ATTR_FLOAT);
1295 LIST_CASE(b, TF_ATTR_BOOL);
1296 LIST_CASE(type, TF_ATTR_TYPE);
1297 LIST_CASE(
1298 shape, TF_ATTR_SHAPE, metadata.total_size = 0;
1299 for (int i = 0; i < attr->list().shape_size(); ++i) {
1300 const auto& s = attr->list().shape(i);
1301 metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
1302 });
1303 LIST_CASE(tensor, TF_ATTR_TENSOR);
1304 LIST_CASE(tensor, TF_ATTR_FUNC);
1305 #undef LIST_CASE
1306 // All lists empty, determine the type from the OpDef.
1307 if (metadata.list_size == 0) {
1308 for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
1309 const auto& a = oper->node.op_def().attr(i);
1310 if (a.name() != attr_name) continue;
1311 const string& typestr = a.type();
1312 if (typestr == "list(string)") {
1313 metadata.type = TF_ATTR_STRING;
1314 } else if (typestr == "list(int)") {
1315 metadata.type = TF_ATTR_INT;
1316 } else if (typestr == "list(float)") {
1317 metadata.type = TF_ATTR_FLOAT;
1318 } else if (typestr == "list(bool)") {
1319 metadata.type = TF_ATTR_BOOL;
1320 } else if (typestr == "list(type)") {
1321 metadata.type = TF_ATTR_TYPE;
1322 } else if (typestr == "list(shape)") {
1323 metadata.type = TF_ATTR_SHAPE;
1324 } else if (typestr == "list(tensor)") {
1325 metadata.type = TF_ATTR_TENSOR;
1326 } else if (typestr == "list(func)") {
1327 metadata.type = TF_ATTR_FUNC;
1328 } else {
1329 status->status = InvalidArgument(
1330 "Attribute '", attr_name,
1331 "' has an empty value of an unrecognized type '", typestr, "'");
1332 return metadata;
1333 }
1334 }
1335 }
1336 break;
1337
1338 case tensorflow::AttrValue::kPlaceholder:
1339 metadata.is_list = 0;
1340 metadata.list_size = -1;
1341 metadata.type = TF_ATTR_PLACEHOLDER;
1342 metadata.total_size = -1;
1343 break;
1344
1345 case tensorflow::AttrValue::kFunc:
1346 metadata.is_list = 0;
1347 metadata.list_size = -1;
1348 metadata.type = TF_ATTR_FUNC;
1349 metadata.total_size = -1;
1350 break;
1351
1352 case tensorflow::AttrValue::VALUE_NOT_SET:
1353 status->status =
1354 InvalidArgument("Attribute '", attr_name, "' has no value set");
1355 break;
1356 }
1357 return metadata;
1358 }
1359
TF_OperationGetAttrString(TF_Operation * oper,const char * attr_name,void * value,size_t max_length,TF_Status * status)1360 void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
1361 void* value, size_t max_length,
1362 TF_Status* status) {
1363 const auto* attr = GetAttrValue(oper, attr_name, status);
1364 if (!status->status.ok()) return;
1365 if (attr->value_case() != tensorflow::AttrValue::kS) {
1366 status->status =
1367 InvalidArgument("Attribute '", attr_name, "' is not a string");
1368 return;
1369 }
1370 if (max_length <= 0) {
1371 return;
1372 }
1373 const auto& s = attr->s();
1374 std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length));
1375 }
1376
TF_OperationGetAttrStringList(TF_Operation * oper,const char * attr_name,void ** values,size_t * lengths,int max_values,void * storage,size_t storage_size,TF_Status * status)1377 void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
1378 void** values, size_t* lengths,
1379 int max_values, void* storage,
1380 size_t storage_size, TF_Status* status) {
1381 const auto* attr = GetAttrValue(oper, attr_name, status);
1382 if (!status->status.ok()) return;
1383 if (attr->value_case() != tensorflow::AttrValue::kList) {
1384 status->status =
1385 InvalidArgument("Value for '", attr_name, "' is not a list");
1386 return;
1387 }
1388 const auto len = std::min(max_values, attr->list().s_size());
1389 char* p = static_cast<char*>(storage);
1390 for (int i = 0; i < len; ++i) {
1391 const string& s = attr->list().s(i);
1392 values[i] = p;
1393 lengths[i] = s.size();
1394 if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
1395 status->status = InvalidArgument(
1396 "Not enough storage to hold the requested list of strings");
1397 return;
1398 }
1399 memcpy(values[i], s.data(), s.size());
1400 p += s.size();
1401 }
1402 }
1403
1404 #define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \
1405 void func(TF_Operation* oper, const char* attr_name, c_type* value, \
1406 TF_Status* status) { \
1407 cpp_type v; \
1408 status->status = \
1409 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \
1410 if (!status->status.ok()) return; \
1411 *value = static_cast<c_type>(v); \
1412 } \
1413 void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
1414 int max_values, TF_Status* status) { \
1415 const auto* attr = GetAttrValue(oper, attr_name, status); \
1416 if (!status->status.ok()) return; \
1417 if (attr->value_case() != tensorflow::AttrValue::kList) { \
1418 status->status = \
1419 InvalidArgument("Value for '", attr_name, "' is not a list."); \
1420 return; \
1421 } \
1422 const auto len = std::min(max_values, attr->list().list_field##_size()); \
1423 for (int i = 0; i < len; ++i) { \
1424 values[i] = static_cast<c_type>(attr->list().list_field(i)); \
1425 } \
1426 }
1427 DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, int64_t, i);
1428 DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
1429 DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
1430 DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
1431 #undef DEFINE_GETATTR
1432
TF_OperationGetAttrShape(TF_Operation * oper,const char * attr_name,int64_t * value,int num_dims,TF_Status * status)1433 void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
1434 int64_t* value, int num_dims, TF_Status* status) {
1435 PartialTensorShape shape;
1436 status->status =
1437 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
1438 if (!status->status.ok()) return;
1439 auto len = std::min(shape.dims(), num_dims);
1440 for (int i = 0; i < len; ++i) {
1441 value[i] = shape.dim_size(i);
1442 }
1443 }
1444
TF_OperationGetAttrShapeList(TF_Operation * oper,const char * attr_name,int64_t ** dims,int * num_dims,int num_shapes,int64_t * storage,int storage_size,TF_Status * status)1445 void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
1446 int64_t** dims, int* num_dims, int num_shapes,
1447 int64_t* storage, int storage_size,
1448 TF_Status* status) {
1449 std::vector<PartialTensorShape> shapes;
1450 status->status =
1451 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
1452 if (!status->status.ok()) return;
1453 auto len = std::min(static_cast<int>(shapes.size()), num_shapes);
1454 int64_t* p = storage;
1455 int storage_left = storage_size;
1456 for (int i = 0; i < len; ++i) {
1457 // shapes[i].dims() == -1 for shapes with an unknown rank.
1458 int64_t n = shapes[i].dims();
1459 num_dims[i] = n;
1460 dims[i] = p;
1461 if (n < 0) {
1462 continue;
1463 }
1464 if (storage_left < n) {
1465 status->status = InvalidArgument(
1466 "Not enough storage to hold the requested list of shapes");
1467 return;
1468 }
1469 storage_left -= n;
1470 for (int j = 0; j < n; ++j, ++p) {
1471 *p = shapes[i].dim_size(j);
1472 }
1473 }
1474 }
1475
TF_OperationGetAttrTensorShapeProto(TF_Operation * oper,const char * attr_name,TF_Buffer * value,TF_Status * status)1476 void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
1477 const char* attr_name,
1478 TF_Buffer* value, TF_Status* status) {
1479 const auto* attr = GetAttrValue(oper, attr_name, status);
1480 if (!status->status.ok()) return;
1481 if (attr->value_case() != tensorflow::AttrValue::kShape) {
1482 status->status =
1483 InvalidArgument("Value for '", attr_name, "' is not a shape.");
1484 return;
1485 }
1486 status->status = MessageToBuffer(attr->shape(), value);
1487 }
1488
TF_OperationGetAttrTensorShapeProtoList(TF_Operation * oper,const char * attr_name,TF_Buffer ** values,int max_values,TF_Status * status)1489 void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
1490 const char* attr_name,
1491 TF_Buffer** values, int max_values,
1492 TF_Status* status) {
1493 const auto* attr = GetAttrValue(oper, attr_name, status);
1494 if (!status->status.ok()) return;
1495 if (attr->value_case() != tensorflow::AttrValue::kList) {
1496 status->status =
1497 InvalidArgument("Value for '", attr_name, "' is not a list");
1498 return;
1499 }
1500 const auto len = std::min(max_values, attr->list().shape_size());
1501 for (int i = 0; i < len; ++i) {
1502 values[i] = TF_NewBuffer();
1503 status->status = MessageToBuffer(attr->list().shape(i), values[i]);
1504 if (!status->status.ok()) {
1505 // Delete everything allocated to far, the operation has failed.
1506 for (int j = 0; j <= i; ++j) {
1507 TF_DeleteBuffer(values[j]);
1508 }
1509 return;
1510 }
1511 }
1512 }
1513
TF_OperationGetAttrTensor(TF_Operation * oper,const char * attr_name,TF_Tensor ** value,TF_Status * status)1514 void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
1515 TF_Tensor** value, TF_Status* status) {
1516 *value = nullptr;
1517 Tensor t;
1518 status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
1519 if (!status->status.ok()) return;
1520 *value = TF_TensorFromTensor(t, &status->status);
1521 }
1522
TF_OperationGetAttrTensorList(TF_Operation * oper,const char * attr_name,TF_Tensor ** values,int max_values,TF_Status * status)1523 void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
1524 TF_Tensor** values, int max_values,
1525 TF_Status* status) {
1526 std::vector<Tensor> ts;
1527 status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
1528 if (!status->status.ok()) return;
1529 const auto len = std::min(max_values, static_cast<int>(ts.size()));
1530 for (int i = 0; i < len; ++i) {
1531 values[i] = TF_TensorFromTensor(ts[i], &status->status);
1532 }
1533 }
1534
TF_OperationGetAttrValueProto(TF_Operation * oper,const char * attr_name,TF_Buffer * output_attr_value,TF_Status * status)1535 void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
1536 TF_Buffer* output_attr_value,
1537 TF_Status* status) {
1538 const auto* attr = GetAttrValue(oper, attr_name, status);
1539 if (!status->status.ok()) return;
1540 status->status = MessageToBuffer(*attr, output_attr_value);
1541 }
1542
TF_OperationToNodeDef(TF_Operation * oper,TF_Buffer * output_node_def,TF_Status * status)1543 void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
1544 TF_Status* status) {
1545 status->status = MessageToBuffer(oper->node.def(), output_node_def);
1546 }
1547
1548 // TF_Graph functions ---------------------------------------------------------
1549
TF_Graph()1550 TF_Graph::TF_Graph()
1551 : graph(tensorflow::OpRegistry::Global()),
1552 refiner(graph.versions().producer(), graph.op_registry()),
1553 delete_requested(false),
1554 parent(nullptr),
1555 parent_inputs(nullptr) {
1556 // Tell the shape refiner to also run shape inference on functions.
1557 refiner.set_function_library_for_shape_inference(&graph.flib_def());
1558 }
1559
TF_NewGraph()1560 TF_Graph* TF_NewGraph() { return new TF_Graph; }
1561
TF_DeleteGraph(TF_Graph * g)1562 void TF_DeleteGraph(TF_Graph* g) {
1563 if (g == nullptr) return;
1564 g->mu.lock();
1565 g->delete_requested = true;
1566 const bool del = g->sessions.empty();
1567 g->mu.unlock();
1568 if (del) delete g;
1569 }
1570
TF_GraphOperationByName(TF_Graph * graph,const char * oper_name)1571 TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
1572 mutex_lock l(graph->mu);
1573 auto iter = graph->name_map.find(oper_name);
1574 if (iter == graph->name_map.end()) {
1575 return nullptr;
1576 } else {
1577 return ToOperation(iter->second);
1578 }
1579 }
1580
TF_GraphNextOperation(TF_Graph * graph,size_t * pos)1581 TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
1582 if (*pos == 0) {
1583 // Advance past the first sentinel nodes in every graph (the source & sink).
1584 *pos += 2;
1585 } else {
1586 // Advance to the next node.
1587 *pos += 1;
1588 }
1589
1590 mutex_lock l(graph->mu);
1591 while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) {
1592 Node* node = graph->graph.FindNodeId(*pos);
1593 // FindNodeId() returns nullptr for nodes that have been deleted.
1594 // We aren't currently allowing nodes to be deleted, but it is safer
1595 // to still check.
1596 if (node != nullptr) return ToOperation(node);
1597 *pos += 1;
1598 }
1599
1600 // No more nodes.
1601 return nullptr;
1602 }
1603
TF_GraphToGraphDef(TF_Graph * graph,TF_Buffer * output_graph_def,TF_Status * status)1604 void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
1605 TF_Status* status) {
1606 GraphDef def;
1607 {
1608 mutex_lock l(graph->mu);
1609 graph->graph.ToGraphDef(&def);
1610 }
1611 status->status = MessageToBuffer(def, output_graph_def);
1612 }
1613
TF_GraphGetOpDef(TF_Graph * graph,const char * op_name,TF_Buffer * output_op_def,TF_Status * status)1614 void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
1615 TF_Buffer* output_op_def, TF_Status* status) {
1616 const OpDef* op_def;
1617 {
1618 mutex_lock l(graph->mu);
1619 status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
1620 if (!status->status.ok()) return;
1621 }
1622 status->status = MessageToBuffer(*op_def, output_op_def);
1623 }
1624
TF_GraphVersions(TF_Graph * graph,TF_Buffer * output_version_def,TF_Status * status)1625 void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
1626 TF_Status* status) {
1627 VersionDef versions;
1628 {
1629 mutex_lock l(graph->mu);
1630 versions = graph->graph.versions();
1631 }
1632 status->status = MessageToBuffer(versions, output_version_def);
1633 }
1634
TF_NewImportGraphDefOptions()1635 TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
1636 return new TF_ImportGraphDefOptions;
1637 }
TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions * opts)1638 void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
1639 delete opts;
1640 }
TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions * opts,const char * prefix)1641 void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
1642 const char* prefix) {
1643 opts->opts.prefix = prefix;
1644 }
TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions * opts,const char * device)1645 void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts,
1646 const char* device) {
1647 opts->opts.default_device = device;
1648 }
1649
TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions * opts,unsigned char uniquify_names)1650 void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
1651 unsigned char uniquify_names) {
1652 opts->opts.uniquify_names = uniquify_names;
1653 }
1654
TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions * opts,unsigned char uniquify_prefix)1655 void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
1656 unsigned char uniquify_prefix) {
1657 opts->opts.uniquify_prefix = uniquify_prefix;
1658 }
1659
TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions * opts,const char * src_name,int src_index,TF_Output dst)1660 void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
1661 const char* src_name,
1662 int src_index, TF_Output dst) {
1663 opts->tensor_id_data.push_back(src_name);
1664 const string& src_name_str = opts->tensor_id_data.back();
1665 // We don't need to store dst's name in tensor_id_data, since `dst` must
1666 // outlive the ImportGraphDef call.
1667 opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
1668 }
1669
TF_ImportGraphDefOptionsRemapControlDependency(TF_ImportGraphDefOptions * opts,const char * src_name,TF_Operation * dst)1670 void TF_ImportGraphDefOptionsRemapControlDependency(
1671 TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
1672 opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
1673 TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
1674 }
1675
TF_ImportGraphDefOptionsAddControlDependency(TF_ImportGraphDefOptions * opts,TF_Operation * oper)1676 extern void TF_ImportGraphDefOptionsAddControlDependency(
1677 TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
1678 opts->opts.control_dependencies.push_back(oper->node.name());
1679 }
1680
TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions * opts,const char * oper_name,int index)1681 void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
1682 const char* oper_name, int index) {
1683 opts->tensor_id_data.push_back(oper_name);
1684 const string& oper_name_str = opts->tensor_id_data.back();
1685 opts->opts.return_tensors.emplace_back(oper_name_str, index);
1686 }
1687
TF_ImportGraphDefOptionsNumReturnOutputs(const TF_ImportGraphDefOptions * opts)1688 int TF_ImportGraphDefOptionsNumReturnOutputs(
1689 const TF_ImportGraphDefOptions* opts) {
1690 return opts->opts.return_tensors.size();
1691 }
1692
TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions * opts,const char * oper_name)1693 void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
1694 const char* oper_name) {
1695 opts->opts.return_nodes.push_back(oper_name);
1696 }
1697
TF_ImportGraphDefOptionsNumReturnOperations(const TF_ImportGraphDefOptions * opts)1698 int TF_ImportGraphDefOptionsNumReturnOperations(
1699 const TF_ImportGraphDefOptions* opts) {
1700 return opts->opts.return_nodes.size();
1701 }
1702
TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults * results,int * num_outputs,TF_Output ** outputs)1703 void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
1704 int* num_outputs,
1705 TF_Output** outputs) {
1706 *num_outputs = results->return_tensors.size();
1707 *outputs = results->return_tensors.data();
1708 }
1709
TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults * results,int * num_opers,TF_Operation *** opers)1710 void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
1711 int* num_opers,
1712 TF_Operation*** opers) {
1713 *num_opers = results->return_nodes.size();
1714 *opers = results->return_nodes.data();
1715 }
1716
TF_ImportGraphDefResultsMissingUnusedInputMappings(TF_ImportGraphDefResults * results,int * num_missing_unused_input_mappings,const char *** src_names,int ** src_indexes)1717 void TF_ImportGraphDefResultsMissingUnusedInputMappings(
1718 TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
1719 const char*** src_names, int** src_indexes) {
1720 *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
1721 *src_names = results->missing_unused_key_names.data();
1722 *src_indexes = results->missing_unused_key_indexes.data();
1723 }
1724
TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults * results)1725 void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
1726 delete results;
1727 }
1728
GraphImportGraphDefLocked(TF_Graph * graph,const GraphDef & def,const TF_ImportGraphDefOptions * opts,TF_ImportGraphDefResults * tf_results,TF_Status * status)1729 static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
1730 const TF_ImportGraphDefOptions* opts,
1731 TF_ImportGraphDefResults* tf_results,
1732 TF_Status* status)
1733 TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
1734 const int last_node_id = graph->graph.num_node_ids();
1735 tensorflow::ImportGraphDefResults results;
1736 status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
1737 &graph->refiner, &results);
1738 if (!status->status.ok()) return;
1739
1740 // Add new nodes to name_map
1741 for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
1742 auto* node = graph->graph.FindNodeId(i);
1743 if (node != nullptr) graph->name_map[node->name()] = node;
1744 }
1745
1746 // Populate return_tensors
1747 DCHECK(tf_results->return_tensors.empty());
1748 tf_results->return_tensors.resize(results.return_tensors.size());
1749 for (int i = 0; i < results.return_tensors.size(); ++i) {
1750 tf_results->return_tensors[i].oper =
1751 ToOperation(results.return_tensors[i].first);
1752 tf_results->return_tensors[i].index = results.return_tensors[i].second;
1753 }
1754
1755 // Populate return_nodes
1756 DCHECK(tf_results->return_nodes.empty());
1757 tf_results->return_nodes.resize(results.return_nodes.size());
1758 for (int i = 0; i < results.return_nodes.size(); ++i) {
1759 tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
1760 }
1761
1762 // Populate missing unused map keys
1763 DCHECK(tf_results->missing_unused_key_names.empty());
1764 DCHECK(tf_results->missing_unused_key_indexes.empty());
1765 DCHECK(tf_results->missing_unused_key_names_data.empty());
1766
1767 size_t size = results.missing_unused_input_map_keys.size();
1768 tf_results->missing_unused_key_names.resize(size);
1769 tf_results->missing_unused_key_indexes.resize(size);
1770
1771 for (int i = 0; i < size; ++i) {
1772 TensorId id = results.missing_unused_input_map_keys[i];
1773 tf_results->missing_unused_key_names_data.emplace_back(id.first);
1774 tf_results->missing_unused_key_names[i] =
1775 tf_results->missing_unused_key_names_data.back().c_str();
1776 tf_results->missing_unused_key_indexes[i] = id.second;
1777 }
1778 }
1779
TF_GraphImportGraphDefWithResults(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Status * status)1780 TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
1781 TF_Graph* graph, const TF_Buffer* graph_def,
1782 const TF_ImportGraphDefOptions* options, TF_Status* status) {
1783 GraphDef def;
1784 if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
1785 graph_def->length)) {
1786 status->status = InvalidArgument("Invalid GraphDef");
1787 return nullptr;
1788 }
1789 auto results = new TF_ImportGraphDefResults();
1790 mutex_lock l(graph->mu);
1791 GraphImportGraphDefLocked(graph, def, options, results, status);
1792 if (!status->status.ok()) {
1793 delete results;
1794 return nullptr;
1795 }
1796 return results;
1797 }
1798
TF_GraphImportGraphDefWithReturnOutputs(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Output * return_outputs,int num_return_outputs,TF_Status * status)1799 void TF_GraphImportGraphDefWithReturnOutputs(
1800 TF_Graph* graph, const TF_Buffer* graph_def,
1801 const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
1802 int num_return_outputs, TF_Status* status) {
1803 if (num_return_outputs != options->opts.return_tensors.size()) {
1804 status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
1805 options->opts.return_tensors.size(),
1806 ", got ", num_return_outputs);
1807 return;
1808 }
1809 if (num_return_outputs > 0 && return_outputs == nullptr) {
1810 status->status = InvalidArgument(
1811 "'return_outputs' must be preallocated to length ", num_return_outputs);
1812 return;
1813 }
1814 GraphDef def;
1815 if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
1816 graph_def->length)) {
1817 status->status = InvalidArgument("Invalid GraphDef");
1818 return;
1819 }
1820 TF_ImportGraphDefResults results;
1821 mutex_lock l(graph->mu);
1822 GraphImportGraphDefLocked(graph, def, options, &results, status);
1823 DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
1824 memcpy(return_outputs, results.return_tensors.data(),
1825 num_return_outputs * sizeof(TF_Output));
1826 }
1827
TF_GraphImportGraphDef(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Status * status)1828 void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
1829 const TF_ImportGraphDefOptions* options,
1830 TF_Status* status) {
1831 TF_ImportGraphDefResults* results =
1832 TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
1833 TF_DeleteImportGraphDefResults(results);
1834 }
1835
1836 // While loop functions -------------------------------------------------------
1837
1838 namespace {
1839
1840 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1841
1842 // Creates a placeholder representing an input to the cond or body graph.
1843 // TODO(skyewm): remove these from final graph
CreateInput(const TF_Output & parent_input,TF_Graph * g,const char * name,TF_Output * input,TF_Status * status)1844 bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
1845 TF_Output* input, TF_Status* status) {
1846 TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
1847 TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
1848 // TODO(skyewm): set placeholder shape
1849 TF_Operation* oper = TF_FinishOperation(desc, status);
1850 if (!status->status.ok()) return false;
1851 *input = {oper, 0};
1852 return true;
1853 }
1854
1855 // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
1856 // `src_inputs[i]` will have that input replaced with `dst_inputs[i]`. `prefix`
1857 // will be prepended to copied node names. `control_deps` are nodes in
1858 // `dst_graph` that the copied `src_graph` nodes will have control dependencies
1859 // on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
1860 // in `dst_graph` will be returned. `return_nodes` must be non-null.
CopyGraph(Graph * src_graph,Graph * dst_graph,tensorflow::ShapeRefiner * dst_refiner,const TF_Output * src_inputs,const std::vector<tensorflow::Output> & dst_inputs,const string & prefix,const std::vector<tensorflow::Operation> & control_deps,const TF_Output * nodes_to_return,int nreturn_nodes,std::vector<tensorflow::Output> * return_nodes)1861 Status CopyGraph(Graph* src_graph, Graph* dst_graph,
1862 tensorflow::ShapeRefiner* dst_refiner,
1863 const TF_Output* src_inputs,
1864 const std::vector<tensorflow::Output>& dst_inputs,
1865 const string& prefix,
1866 const std::vector<tensorflow::Operation>& control_deps,
1867 const TF_Output* nodes_to_return, int nreturn_nodes,
1868 std::vector<tensorflow::Output>* return_nodes) {
1869 DCHECK(return_nodes != nullptr);
1870 GraphDef gdef;
1871 src_graph->ToGraphDef(&gdef);
1872
1873 tensorflow::ImportGraphDefOptions opts;
1874 opts.prefix = prefix;
1875
1876 for (int i = 0; i < dst_inputs.size(); ++i) {
1877 opts.input_map[ToTensorId(src_inputs[i])] =
1878 TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
1879 }
1880 opts.skip_mapped_nodes = true;
1881
1882 for (const tensorflow::Operation& op : control_deps) {
1883 opts.control_dependencies.push_back(op.node()->name());
1884 }
1885
1886 for (int i = 0; i < nreturn_nodes; ++i) {
1887 opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
1888 }
1889
1890 // TODO(skyewm): change to OutputTensor
1891 tensorflow::ImportGraphDefResults results;
1892 TF_RETURN_IF_ERROR(
1893 ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
1894
1895 for (const auto& pair : results.return_tensors) {
1896 return_nodes->emplace_back(pair.first, pair.second);
1897 }
1898 return Status::OK();
1899 }
1900
ValidateConstWhileParams(const TF_WhileParams & params,TF_Status * s)1901 bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
1902 if (params.cond_graph == nullptr || params.body_graph == nullptr ||
1903 params.cond_graph->parent == nullptr ||
1904 params.cond_graph->parent != params.body_graph->parent ||
1905 params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
1906 params.ninputs <= 0 || params.cond_inputs == nullptr ||
1907 params.body_inputs == nullptr || params.body_outputs == nullptr) {
1908 s->status = InvalidArgument(
1909 "TF_WhileParams must be created by successful TF_NewWhile() call");
1910 return false;
1911 }
1912 return true;
1913 }
1914
ValidateInputWhileParams(const TF_WhileParams & params,TF_Status * s)1915 bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
1916 if (params.cond_output.oper == nullptr) {
1917 s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
1918 return false;
1919 }
1920 for (int i = 0; i < params.ninputs; ++i) {
1921 if (params.body_outputs[i].oper == nullptr) {
1922 s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
1923 "field isn't set");
1924 return false;
1925 }
1926 }
1927 if (params.name == nullptr) {
1928 s->status = InvalidArgument("TF_WhileParams `name` field is null");
1929 return false;
1930 }
1931 return true;
1932 }
1933
1934 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1935
FreeWhileResources(const TF_WhileParams * params)1936 void FreeWhileResources(const TF_WhileParams* params) {
1937 TF_DeleteGraph(params->cond_graph);
1938 TF_DeleteGraph(params->body_graph);
1939 delete[] params->cond_inputs;
1940 delete[] params->body_inputs;
1941 delete[] params->body_outputs;
1942 }
1943
EmptyWhileParams()1944 TF_WhileParams EmptyWhileParams() {
1945 return {0, nullptr, nullptr, {nullptr, 0},
1946 nullptr, nullptr, nullptr, nullptr};
1947 }
1948
1949 } // namespace
1950
TF_NewWhile(TF_Graph * g,TF_Output * inputs,int ninputs,TF_Status * status)1951 TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
1952 TF_Status* status) {
1953 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
1954 status->status = tensorflow::errors::Unimplemented(
1955 "Creating while loops is not supported on mobile. File a bug at "
1956 "https://github.com/tensorflow/tensorflow/issues if this feature is "
1957 "important to you");
1958 return EmptyWhileParams();
1959 #else
1960 if (ninputs == 0) {
1961 status->status =
1962 InvalidArgument("TF_NewWhile() must be passed at least one input");
1963 return EmptyWhileParams();
1964 }
1965
1966 TF_Graph* cond_graph = TF_NewGraph();
1967 TF_Graph* body_graph = TF_NewGraph();
1968 cond_graph->parent = g;
1969 cond_graph->parent_inputs = inputs;
1970 body_graph->parent = g;
1971 body_graph->parent_inputs = inputs;
1972
1973 TF_Output* cond_inputs = new TF_Output[ninputs];
1974 TF_Output cond_output = {nullptr, -1};
1975 TF_Output* body_inputs = new TF_Output[ninputs];
1976 TF_Output* body_outputs = new TF_Output[ninputs];
1977 for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
1978 const char* name = nullptr;
1979
1980 for (int i = 0; i < ninputs; ++i) {
1981 // TODO(skyewm): prefix names with underscore (requires some plumbing)
1982 if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
1983 &cond_inputs[i], status)) {
1984 break;
1985 }
1986 if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
1987 &body_inputs[i], status)) {
1988 break;
1989 }
1990 }
1991
1992 TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output,
1993 body_graph, body_inputs, body_outputs, name};
1994
1995 if (!status->status.ok()) {
1996 FreeWhileResources(¶ms);
1997 return EmptyWhileParams();
1998 }
1999 return params;
2000 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2001 }
2002
2003 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2004 namespace {
2005
2006 // TODO(skyewm): make nodes in while loop unfetchable like in Python version
TF_FinishWhileHelper(const TF_WhileParams * params,TF_Status * status,TF_Output * outputs)2007 void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
2008 TF_Output* outputs) {
2009 if (!ValidateInputWhileParams(*params, status)) return;
2010
2011 TF_Graph* parent = params->cond_graph->parent;
2012 TF_Output* parent_inputs = params->cond_graph->parent_inputs;
2013 int num_loop_vars = params->ninputs;
2014
2015 mutex_lock l(parent->mu);
2016
2017 // 'cond_fn' copies the cond graph into the parent graph.
2018 tensorflow::ops::CondGraphBuilderFn cond_fn =
2019 [params, parent](const tensorflow::Scope& scope,
2020 const std::vector<tensorflow::Output>& inputs,
2021 tensorflow::Output* output) {
2022 DCHECK_EQ(scope.graph(), &parent->graph);
2023 std::vector<tensorflow::Output> cond_output;
2024 TF_RETURN_IF_ERROR(CopyGraph(
2025 ¶ms->cond_graph->graph, &parent->graph, &parent->refiner,
2026 params->cond_inputs, inputs, scope.impl()->name(),
2027 scope.impl()->control_deps(), ¶ms->cond_output,
2028 /* nreturn_nodes */ 1, &cond_output));
2029 *output = cond_output[0];
2030 return Status::OK();
2031 };
2032
2033 // 'body_fn' copies the body graph into the parent graph.
2034 tensorflow::ops::BodyGraphBuilderFn body_fn =
2035 [params, parent, num_loop_vars](
2036 const tensorflow::Scope& scope,
2037 const std::vector<tensorflow::Output>& inputs,
2038 std::vector<tensorflow::Output>* outputs) {
2039 DCHECK_EQ(scope.graph(), &parent->graph);
2040 TF_RETURN_IF_ERROR(
2041 CopyGraph(¶ms->body_graph->graph, &parent->graph,
2042 &parent->refiner, params->body_inputs, inputs,
2043 scope.impl()->name(), scope.impl()->control_deps(),
2044 params->body_outputs, num_loop_vars, outputs));
2045 return Status::OK();
2046 };
2047
2048 // Create the while loop using an internal scope.
2049 tensorflow::Scope scope =
2050 NewInternalScope(&parent->graph, &status->status, &parent->refiner)
2051 .NewSubScope(params->name);
2052
2053 const int first_new_node_id = parent->graph.num_node_ids();
2054
2055 tensorflow::OutputList loop_outputs;
2056 status->status = tensorflow::ops::BuildWhileLoop(
2057 scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
2058 body_fn, params->name, &loop_outputs);
2059
2060 // Update name_map with newly-created ops.
2061 // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
2062 // a bad status. Once we fix this, we may want to return early instead of
2063 // executing the following code.
2064 for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
2065 Node* new_node = parent->graph.FindNodeId(i);
2066 if (new_node == nullptr) continue;
2067 parent->name_map[new_node->name()] = new_node;
2068 }
2069
2070 // Populate 'outputs'.
2071 DCHECK_LE(loop_outputs.size(), num_loop_vars);
2072 for (int i = 0; i < loop_outputs.size(); ++i) {
2073 outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
2074 }
2075 }
2076
2077 } // namespace
2078 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2079
TF_FinishWhile(const TF_WhileParams * params,TF_Status * status,TF_Output * outputs)2080 void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
2081 TF_Output* outputs) {
2082 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2083 status->status = tensorflow::errors::Unimplemented(
2084 "Creating while loops is not supported on mobile. File a bug at "
2085 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2086 "important to you");
2087 #else
2088 // If it appears the caller created or modified `params`, don't free resources
2089 if (!ValidateConstWhileParams(*params, status)) return;
2090 TF_FinishWhileHelper(params, status, outputs);
2091 FreeWhileResources(params);
2092 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2093 }
2094
TF_AbortWhile(const TF_WhileParams * params)2095 void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
2096
TF_AddGradients(TF_Graph * g,TF_Output * y,int ny,TF_Output * x,int nx,TF_Output * dx,TF_Status * status,TF_Output * dy)2097 void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
2098 TF_Output* dx, TF_Status* status, TF_Output* dy) {
2099 TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy);
2100 }
2101
// Adds gradient ops to `g` via AddSymbolicGradients() for the `ny` outputs in
// `y` with respect to the `nx` outputs in `x`, and writes the resulting
// outputs to `dy`. If `dx` is non-null, `ny` of its entries are passed as the
// initial gradients (presumably one per `y` — confirm with the
// AddSymbolicGradients() contract).
//
// New ops are created under the sub-scope `prefix`, or "gradients" when
// `prefix` is null. A non-null `prefix` must not collide with any existing
// node name: neither equal to it nor a prefix of it followed by '/'.
//
// NOTE(review): if the "prefix altered" check fires, the function returns
// early: nodes with higher ids are not added to g->name_map and `dy` is not
// written. A duplicate-name error does not stop the loop, and `dy` is still
// filled from whatever `dy_arg` holds.
void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
                               int ny, TF_Output* x, int nx, TF_Output* dx,
                               TF_Status* status, TF_Output* dy) {
#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
  status->status = tensorflow::errors::Unimplemented(
      "Adding gradients is not supported on mobile. File a bug at "
      "https://github.com/tensorflow/tensorflow/issues if this feature is "
      "important to you");
#else
  std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
  std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
  std::vector<tensorflow::Output> dy_arg;

  {
    // We need to hold on to the lock while we have a scope that uses TF_Graph.
    mutex_lock graph_lock(g->mu);

    // Every node created below gets an id >= this value.
    const int first_new_node_id = g->graph.num_node_ids();

    // "<prefix>/" when a prefix is given; used for both the collision check
    // and the post-construction sanity check.
    string prefix_cmp;
    const char* child_scope_name;
    if (prefix == nullptr) {
      child_scope_name = "gradients";
    } else {
      prefix_cmp = string(prefix) + "/";
      // The operation should fail if the provided name prefix has already been
      // used in this graph
      for (const auto& pair : g->name_map) {
        const string& name = pair.first;
        if ((name == prefix) || absl::StartsWith(name, prefix_cmp)) {
          status->status = InvalidArgument(
              "prefix [", prefix,
              "] conflicts with existing node in the graph named [", name, "]");
          return;
        }
      }
      child_scope_name = prefix;
    }
    // The scope records its own errors into status->status; the assignments
    // from AddSymbolicGradients() below then overwrite that value.
    tensorflow::Scope scope =
        NewInternalScope(&g->graph, &status->status, &g->refiner)
            .NewSubScope(child_scope_name);

    if (dx != nullptr) {
      std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
      status->status =
          AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
    } else {
      status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
    }

    // Register every node created above (ids >= first_new_node_id) in
    // g->name_map so name lookups see the new gradient ops.
    for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
      Node* n = g->graph.FindNodeId(i);
      if (n == nullptr) continue;

      // Adding the gradients to the graph can alter the prefix to prevent
      // name collisions only if this prefix has not been provided explicitly
      // by the user. If it was provided, assert that it remained intact.
      if (prefix != nullptr && !absl::StartsWith(n->name(), prefix_cmp)) {
        status->status = tensorflow::errors::Internal(
            "BUG: The gradients prefix have been unexpectedly altered when "
            "adding the nodes to the graph. This is a bug. Please file an "
            "issue at https://github.com/tensorflow/tensorflow/issues.");
        return;
      }
      // We have a convoluted scheme here: Using the C++ graph construction API
      // to add potentially many nodes to the graph without running the checks
      // (such as uniqueness of the names of nodes) we run with other functions
      // that add a node to the graph (like TF_FinishOperation).
      if (!g->name_map.insert(std::make_pair(n->name(), n)).second) {
        status->status = tensorflow::errors::Internal(
            "BUG: The API allowed construction of a graph with duplicate node "
            "names (",
            n->name(),
            "). This is a bug. Please file an issue at "
            "https://github.com/tensorflow/tensorflow/issues.");
      }
    }
  }

  // Unpack the gradient outputs collected in dy_arg into the caller's `dy`.
  TFOutputsFromOutputs(dy_arg, dy);
#endif  // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
}
2187
2188 // TF_Session functions ----------------------------------------------
2189
// Wraps the session `s` together with the graph `g` it runs; `g` may be null
// (see TF_NewSession/TF_DeleteSession). `extend_before_run` starts true, so
// TF_SessionRun() calls ExtendSessionGraphHelper() before running, and
// `last_num_graph_nodes` starts at 0 — presumably meaning "no graph nodes sent
// to the session yet"; NOTE(review): confirm in ExtendSessionGraphHelper().
TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
    : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {}
2192
TF_NewSession(TF_Graph * graph,const TF_SessionOptions * opt,TF_Status * status)2193 TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
2194 TF_Status* status) {
2195 Session* session;
2196 status->status = NewSession(opt->options, &session);
2197 if (status->status.ok()) {
2198 TF_Session* new_session = new TF_Session(session, graph);
2199 if (graph != nullptr) {
2200 mutex_lock l(graph->mu);
2201 graph->sessions[new_session] = "";
2202 }
2203 return new_session;
2204 } else {
2205 LOG(ERROR) << status->status;
2206 DCHECK_EQ(nullptr, session);
2207 return nullptr;
2208 }
2209 }
2210
TF_LoadSessionFromSavedModel(const TF_SessionOptions * session_options,const TF_Buffer * run_options,const char * export_dir,const char * const * tags,int tags_len,TF_Graph * graph,TF_Buffer * meta_graph_def,TF_Status * status)2211 TF_Session* TF_LoadSessionFromSavedModel(
2212 const TF_SessionOptions* session_options, const TF_Buffer* run_options,
2213 const char* export_dir, const char* const* tags, int tags_len,
2214 TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
2215 // TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring
2216 // that the tensorflow/cc/saved_model:loader build target is mobile friendly.
2217 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2218 status->status = tensorflow::errors::Unimplemented(
2219 "Loading a SavedModel is not supported on mobile. File a bug at "
2220 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2221 "important to you");
2222 return nullptr;
2223 #else
2224 mutex_lock l(graph->mu);
2225 if (!graph->name_map.empty()) {
2226 status->status = InvalidArgument("Graph is non-empty.");
2227 return nullptr;
2228 }
2229
2230 RunOptions run_options_proto;
2231 if (run_options != nullptr && !run_options_proto.ParseFromArray(
2232 run_options->data, run_options->length)) {
2233 status->status = InvalidArgument("Unparseable RunOptions proto");
2234 return nullptr;
2235 }
2236
2237 std::unordered_set<string> tag_set;
2238 for (int i = 0; i < tags_len; i++) {
2239 tag_set.insert(string(tags[i]));
2240 }
2241
2242 tensorflow::SavedModelBundle bundle;
2243 status->status =
2244 tensorflow::LoadSavedModel(session_options->options, run_options_proto,
2245 export_dir, tag_set, &bundle);
2246 if (!status->status.ok()) return nullptr;
2247
2248 // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
2249 // extends using GraphDefs. The Graph instance is different, but equivalent
2250 // to the one used to create the session.
2251 //
2252 // TODO(jhseu): When Session is modified to take Graphs instead of
2253 // GraphDefs, return the Graph generated in LoadSavedModel().
2254 TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
2255 TF_ImportGraphDefResults results;
2256 GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
2257 import_opts, &results, status);
2258 TF_DeleteImportGraphDefOptions(import_opts);
2259 if (!status->status.ok()) return nullptr;
2260
2261 if (meta_graph_def != nullptr) {
2262 status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
2263 if (!status->status.ok()) return nullptr;
2264 }
2265
2266 TF_Session* session = new TF_Session(bundle.session.release(), graph);
2267
2268 graph->sessions[session] = "";
2269 session->last_num_graph_nodes = graph->graph.num_node_ids();
2270 return session;
2271 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2272 }
2273
TF_CloseSession(TF_Session * s,TF_Status * status)2274 void TF_CloseSession(TF_Session* s, TF_Status* status) {
2275 status->status = s->session->Close();
2276 }
2277
TF_DeleteSession(TF_Session * s,TF_Status * status)2278 void TF_DeleteSession(TF_Session* s, TF_Status* status) {
2279 status->status = Status::OK();
2280 if (s == nullptr) return;
2281 TF_Graph* const graph = s->graph;
2282 if (graph != nullptr) {
2283 graph->mu.lock();
2284 graph->sessions.erase(s);
2285 const bool del = graph->delete_requested && graph->sessions.empty();
2286 graph->mu.unlock();
2287 if (del) delete graph;
2288 }
2289 delete s->session;
2290 delete s;
2291 }
2292
TF_SessionRun(TF_Session * session,const TF_Buffer * run_options,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Buffer * run_metadata,TF_Status * status)2293 void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
2294 const TF_Output* inputs, TF_Tensor* const* input_values,
2295 int ninputs, const TF_Output* outputs,
2296 TF_Tensor** output_values, int noutputs,
2297 const TF_Operation* const* target_opers, int ntargets,
2298 TF_Buffer* run_metadata, TF_Status* status) {
2299 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2300 // directly, instead of requiring us to serialize to a GraphDef and
2301 // call Session::Extend().
2302 if (session->extend_before_run &&
2303 !ExtendSessionGraphHelper(session, status)) {
2304 return;
2305 }
2306
2307 TF_Run_Setup(noutputs, output_values, status);
2308
2309 // Convert from TF_Output and TF_Tensor to a string and Tensor.
2310 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2311 if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2312 for (int i = 0; i < ninputs; ++i) {
2313 input_pairs[i].first = OutputName(inputs[i]);
2314 }
2315
2316 // Convert from TF_Output to string names.
2317 std::vector<string> output_names(noutputs);
2318 for (int i = 0; i < noutputs; ++i) {
2319 output_names[i] = OutputName(outputs[i]);
2320 }
2321
2322 // Convert from TF_Operation* to string names.
2323 std::vector<string> target_names(ntargets);
2324 for (int i = 0; i < ntargets; ++i) {
2325 target_names[i] = target_opers[i]->node.name();
2326 }
2327
2328 // Actually run.
2329 TF_Run_Helper(session->session, nullptr, run_options, input_pairs,
2330 output_names, output_values, target_names, run_metadata,
2331 status);
2332 }
2333
TF_SessionPRunSetup(TF_Session * session,const TF_Output * inputs,int ninputs,const TF_Output * outputs,int noutputs,const TF_Operation * const * target_opers,int ntargets,const char ** handle,TF_Status * status)2334 void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
2335 int ninputs, const TF_Output* outputs, int noutputs,
2336 const TF_Operation* const* target_opers, int ntargets,
2337 const char** handle, TF_Status* status) {
2338 *handle = nullptr;
2339
2340 if (session->extend_before_run &&
2341 !ExtendSessionGraphHelper(session, status)) {
2342 return;
2343 }
2344
2345 std::vector<string> input_names(ninputs);
2346 for (int i = 0; i < ninputs; ++i) {
2347 input_names[i] = OutputName(inputs[i]);
2348 }
2349
2350 std::vector<string> output_names(noutputs);
2351 for (int i = 0; i < noutputs; ++i) {
2352 output_names[i] = OutputName(outputs[i]);
2353 }
2354
2355 std::vector<string> target_names(ntargets);
2356 for (int i = 0; i < ntargets; ++i) {
2357 target_names[i] = target_opers[i]->node.name();
2358 }
2359
2360 string new_handle;
2361 status->status = session->session->PRunSetup(input_names, output_names,
2362 target_names, &new_handle);
2363 if (status->status.ok()) {
2364 char* buf = new char[new_handle.size() + 1];
2365 memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
2366 *handle = buf;
2367 }
2368 }
2369
// Frees a handle returned by TF_SessionPRunSetup (which allocates it with
// new[]). Passing nullptr is a no-op.
void TF_DeletePRunHandle(const char* handle) {
  // TODO(suharshs): Free up any resources held by the partial run state.
  delete[] handle;
}
2374
TF_SessionPRun(TF_Session * session,const char * handle,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Status * status)2375 void TF_SessionPRun(TF_Session* session, const char* handle,
2376 const TF_Output* inputs, TF_Tensor* const* input_values,
2377 int ninputs, const TF_Output* outputs,
2378 TF_Tensor** output_values, int noutputs,
2379 const TF_Operation* const* target_opers, int ntargets,
2380 TF_Status* status) {
2381 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2382 // directly, instead of requiring us to serialize to a GraphDef and
2383 // call Session::Extend().
2384 if (session->extend_before_run &&
2385 !ExtendSessionGraphHelper(session, status)) {
2386 return;
2387 }
2388
2389 TF_Run_Setup(noutputs, output_values, status);
2390
2391 // Convert from TF_Output and TF_Tensor to a string and Tensor.
2392 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2393 if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2394 for (int i = 0; i < ninputs; ++i) {
2395 input_pairs[i].first = OutputName(inputs[i]);
2396 }
2397
2398 // Convert from TF_Output to string names.
2399 std::vector<string> output_names(noutputs);
2400 for (int i = 0; i < noutputs; ++i) {
2401 output_names[i] = OutputName(outputs[i]);
2402 }
2403
2404 // Convert from TF_Operation* to string names.
2405 std::vector<string> target_names(ntargets);
2406 for (int i = 0; i < ntargets; ++i) {
2407 target_names[i] = target_opers[i]->node.name();
2408 }
2409
2410 TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
2411 output_values, target_names, nullptr, status);
2412 }
2413
TF_TryEvaluateConstant(TF_Graph * graph,TF_Output output,TF_Tensor ** result,TF_Status * status)2414 unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
2415 TF_Tensor** result, TF_Status* status) {
2416 *result = nullptr;
2417 mutex_lock l(graph->mu);
2418 OutputTensor tensor(&output.oper->node, output.index);
2419 bool evaluated;
2420 Tensor result_tensor;
2421 status->status = EvaluateConstantTensor(
2422 tensor, graph->refiner, *graph->graph.op_registry(),
2423 graph->graph.versions().producer(), &evaluated, &result_tensor);
2424 if (evaluated) {
2425 DCHECK(status->status.ok());
2426 *result = TF_TensorFromTensor(result_tensor, &status->status);
2427 if (!status->status.ok()) evaluated = false;
2428 }
2429 return evaluated;
2430 }
2431
TF_NewApiDefMap(TF_Buffer * op_list_buffer,TF_Status * status)2432 TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
2433 tensorflow::OpList op_list;
2434 if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
2435 status->status = InvalidArgument("Unparseable OpList");
2436 return nullptr;
2437 }
2438 status->status = Status::OK();
2439 return new TF_ApiDefMap(op_list);
2440 }
2441
TF_DeleteApiDefMap(TF_ApiDefMap * apimap)2442 void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }
2443
TF_ApiDefMapPut(TF_ApiDefMap * api_def_map,const char * text,size_t text_len,TF_Status * status)2444 void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
2445 size_t text_len, TF_Status* status) {
2446 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2447 status->status = tensorflow::errors::Unimplemented(
2448 "ApiDefMap is not supported on mobile.");
2449 #else
2450 mutex_lock l(api_def_map->lock);
2451 if (api_def_map->update_docs_called) {
2452 status->status = FailedPrecondition(
2453 "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
2454 "called.");
2455 return;
2456 }
2457 string api_def_text(text, text_len);
2458 status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
2459 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2460 }
2461
TF_ApiDefMapGet(TF_ApiDefMap * api_def_map,const char * name,size_t name_len,TF_Status * status)2462 TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
2463 size_t name_len, TF_Status* status) {
2464 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2465 status->status = tensorflow::errors::Unimplemented(
2466 "ApiDefMap is not supported on mobile.");
2467 return nullptr;
2468 #else
2469 mutex_lock l(api_def_map->lock);
2470 if (!api_def_map->update_docs_called) {
2471 api_def_map->api_def_map.UpdateDocs();
2472 api_def_map->update_docs_called = true;
2473 }
2474 string name_str(name, name_len);
2475 const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
2476 if (api_def == nullptr) {
2477 return nullptr;
2478 }
2479
2480 TF_Buffer* ret = TF_NewBuffer();
2481 status->status = MessageToBuffer(*api_def, ret);
2482 if (!status->status.ok()) {
2483 TF_DeleteBuffer(ret);
2484 return nullptr;
2485 }
2486 return ret;
2487 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2488 }
2489
TF_GetAllRegisteredKernels(TF_Status * status)2490 TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
2491 tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
2492 TF_Buffer* ret = TF_NewBuffer();
2493 status->status = MessageToBuffer(kernel_list, ret);
2494 if (!status->status.ok()) {
2495 TF_DeleteBuffer(ret);
2496 return nullptr;
2497 }
2498 return ret;
2499 }
2500
TF_GetRegisteredKernelsForOp(const char * name,TF_Status * status)2501 TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
2502 tensorflow::KernelList kernel_list =
2503 tensorflow::GetRegisteredKernelsForOp(name);
2504 TF_Buffer* ret = TF_NewBuffer();
2505 status->status = MessageToBuffer(kernel_list, ret);
2506 if (!status->status.ok()) {
2507 TF_DeleteBuffer(ret);
2508 return nullptr;
2509 }
2510 return ret;
2511 }
2512
TF_UpdateEdge(TF_Graph * graph,TF_Output new_src,TF_Input dst,TF_Status * status)2513 void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
2514 TF_Status* status) {
2515 using tensorflow::RecordMutation;
2516 mutex_lock l(graph->mu);
2517 tensorflow::shape_inference::InferenceContext* ic =
2518 graph->refiner.GetContext(&new_src.oper->node);
2519
2520 if (ic->num_outputs() <= new_src.index) {
2521 status->status = tensorflow::errors::OutOfRange(
2522 "Cannot update edge. Output index [", new_src.index,
2523 "] is greater than the number of total outputs [", ic->num_outputs(),
2524 "].");
2525 return;
2526 }
2527 tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
2528
2529 tensorflow::shape_inference::InferenceContext* ic_dst =
2530 graph->refiner.GetContext(&dst.oper->node);
2531 if (ic_dst->num_inputs() <= dst.index) {
2532 status->status = tensorflow::errors::OutOfRange(
2533 "Cannot update edge. Input index [", dst.index,
2534 "] is greater than the number of total inputs [", ic_dst->num_inputs(),
2535 "].");
2536 return;
2537 }
2538 if (!ic_dst->MergeInput(dst.index, shape)) {
2539 status->status = tensorflow::errors::InvalidArgument(
2540 "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
2541 " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
2542 return;
2543 }
2544 status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
2545 &dst.oper->node, dst.index);
2546
2547 if (TF_GetCode(status) == TF_OK) {
2548 // This modification only updates the destination node for
2549 // the purposes of running this graph in a session. Thus, we don't
2550 // record the source node as being modified.
2551 RecordMutation(graph, *dst.oper, "updating input tensor");
2552 }
2553 }
2554
2555 // TF_Server functions ----------------------------------------------
2556
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
// Takes ownership of `server` and caches its target string.
// NOTE(review): in `target(server->target())` the name `server` refers to the
// constructor parameter, which `server(std::move(server))` moves from. This is
// only safe if `target` is initialized first, i.e. declared before `server` in
// TF_Server. Confirm the declaration order in c_api_internal.h before
// reordering the members.
TF_Server::TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)
    : target(server->target()), server(std::move(server)) {}
#endif  // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2561
TF_NewServer(const void * proto,size_t proto_len,TF_Status * status)2562 TF_Server* TF_NewServer(const void* proto, size_t proto_len,
2563 TF_Status* status) {
2564 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2565 status->status = tensorflow::errors::Unimplemented(
2566 "Server functionality is not supported on mobile");
2567 return nullptr;
2568 #else
2569 tensorflow::ServerDef server_def;
2570 if (!server_def.ParseFromArray(proto, static_cast<int>(proto_len))) {
2571 status->status = InvalidArgument(
2572 "Could not parse provided bytes into a ServerDef protocol buffer");
2573 return nullptr;
2574 }
2575
2576 std::unique_ptr<tensorflow::ServerInterface> out_server;
2577 status->status = tensorflow::NewServer(server_def, &out_server);
2578 if (!status->status.ok()) return nullptr;
2579
2580 return new TF_Server(std::move(out_server));
2581 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2582 }
2583
TF_ServerStart(TF_Server * server,TF_Status * status)2584 void TF_ServerStart(TF_Server* server, TF_Status* status) {
2585 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2586 status->status = tensorflow::errors::Unimplemented(
2587 "Server functionality is not supported on mobile");
2588 #else
2589 status->status = server->server->Start();
2590 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2591 }
2592
TF_ServerStop(TF_Server * server,TF_Status * status)2593 void TF_ServerStop(TF_Server* server, TF_Status* status) {
2594 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2595 status->status = tensorflow::errors::Unimplemented(
2596 "Server functionality is not supported on mobile");
2597 #else
2598 status->status = server->server->Stop();
2599 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2600 }
2601
TF_ServerJoin(TF_Server * server,TF_Status * status)2602 void TF_ServerJoin(TF_Server* server, TF_Status* status) {
2603 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2604 status->status = tensorflow::errors::Unimplemented(
2605 "Server functionality is not supported on mobile");
2606 #else
2607 status->status = server->server->Join();
2608 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2609 }
2610
TF_ServerTarget(TF_Server * server)2611 const char* TF_ServerTarget(TF_Server* server) {
2612 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2613 return nullptr;
2614 #else
2615 return server->target.c_str();
2616 #endif
2617 }
2618
TF_DeleteServer(TF_Server * server)2619 void TF_DeleteServer(TF_Server* server) {
2620 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2621 delete server;
2622 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2623 }
2624
TF_RegisterLogListener(void (* listener)(const char *))2625 void TF_RegisterLogListener(void (*listener)(const char*)) {
2626 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2627 tensorflow::logging::RegisterListener(listener);
2628 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2629 }
2630
TF_RegisterFilesystemPlugin(const char * plugin_filename,TF_Status * status)2631 void TF_RegisterFilesystemPlugin(const char* plugin_filename,
2632 TF_Status* status) {
2633 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2634 status->status = tensorflow::errors::Unimplemented(
2635 "FileSystem plugin functionality is not supported on mobile");
2636 #else
2637 status->status = tensorflow::RegisterFilesystemPlugin(plugin_filename);
2638 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2639 }
2640
2641 } // end extern "C"
2642