1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/c_api.h"
17
18 #include <algorithm>
19 #include <limits>
20 #include <memory>
21 #include <vector>
22
23 #include "absl/strings/match.h"
24 // Required for IS_MOBILE_PLATFORM
25 #include "tensorflow/core/platform/platform.h" // NOLINT
26
27 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
28 #include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
29 #include "tensorflow/cc/framework/gradients.h"
30 #include "tensorflow/cc/framework/ops.h"
31 #include "tensorflow/cc/framework/scope_internal.h"
32 #include "tensorflow/cc/ops/while_loop.h"
33 #include "tensorflow/cc/saved_model/loader.h"
34 #include "tensorflow/core/distributed_runtime/server_lib.h"
35 #include "tensorflow/core/framework/logging.h"
36 #include "tensorflow/core/framework/op_gen_lib.h"
37 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
38 #include "tensorflow/c/c_api_internal.h"
39 #include "tensorflow/c/tf_buffer_internal.h"
40 #include "tensorflow/c/tf_status_internal.h"
41 #include "tensorflow/c/tf_tensor.h"
42 #include "tensorflow/core/common_runtime/device_mgr.h"
43 #include "tensorflow/core/common_runtime/eval_const_tensor.h"
44 #include "tensorflow/core/common_runtime/graph_constructor.h"
45 #include "tensorflow/core/common_runtime/shape_refiner.h"
46 #include "tensorflow/core/framework/allocation_description.pb.h"
47 #include "tensorflow/core/framework/kernel_def.pb.h"
48 #include "tensorflow/core/framework/log_memory.h"
49 #include "tensorflow/core/framework/node_def_util.h"
50 #include "tensorflow/core/framework/op_kernel.h"
51 #include "tensorflow/core/framework/partial_tensor_shape.h"
52 #include "tensorflow/core/framework/tensor.h"
53 #include "tensorflow/core/framework/tensor.pb.h" // NOLINT
54 #include "tensorflow/core/framework/tensor_shape.h"
55 #include "tensorflow/core/framework/tensor_shape.pb.h"
56 #include "tensorflow/core/framework/types.h"
57 #include "tensorflow/core/framework/versions.pb.h"
58 #include "tensorflow/core/graph/graph.h"
59 #include "tensorflow/core/graph/node_builder.h"
60 #include "tensorflow/core/graph/validate.h"
61 #include "tensorflow/core/lib/gtl/array_slice.h"
62 #include "tensorflow/core/platform/coding.h"
63 #include "tensorflow/core/platform/errors.h"
64 #include "tensorflow/core/platform/mem.h"
65 #include "tensorflow/core/platform/mutex.h"
66 #include "tensorflow/core/platform/protobuf.h"
67 #include "tensorflow/core/platform/status.h"
68 #include "tensorflow/core/platform/str_util.h"
69 #include "tensorflow/core/platform/strcat.h"
70 #include "tensorflow/core/platform/stringpiece.h"
71 #include "tensorflow/core/platform/thread_annotations.h"
72 #include "tensorflow/core/platform/types.h"
73 #include "tensorflow/core/public/session.h"
74 #include "tensorflow/core/public/version.h"
75
// The implementation below is at the top level instead of inside the
// tensorflow namespace because we are defining 'extern "C"' functions.
78 using tensorflow::AllocationDescription;
79 using tensorflow::AttrValueMap;
80 using tensorflow::DataType;
81 using tensorflow::ExtendSessionGraphHelper;
82 using tensorflow::Graph;
83 using tensorflow::GraphDef;
84 using tensorflow::mutex_lock;
85 using tensorflow::NameRangeMap;
86 using tensorflow::NameRangesForNode;
87 using tensorflow::NewSession;
88 using tensorflow::Node;
89 using tensorflow::NodeBuilder;
90 using tensorflow::NodeDef;
91 using tensorflow::OpDef;
92 using tensorflow::OpRegistry;
93 using tensorflow::OutputTensor;
94 using tensorflow::PartialTensorShape;
95 using tensorflow::RunMetadata;
96 using tensorflow::RunOptions;
97 using tensorflow::Session;
98 using tensorflow::Status;
99 using tensorflow::string;
100 using tensorflow::Tensor;
101 using tensorflow::TensorBuffer;
102 using tensorflow::TensorId;
103 using tensorflow::TensorShape;
104 using tensorflow::TensorShapeProto;
105 using tensorflow::VersionDef;
106 using tensorflow::errors::FailedPrecondition;
107 using tensorflow::errors::InvalidArgument;
108 using tensorflow::errors::OutOfRange;
109 using tensorflow::gtl::ArraySlice;
110 using tensorflow::strings::StrCat;
111
112 extern "C" {
113
114 // --------------------------------------------------------------------------
TF_Version()115 const char* TF_Version() { return TF_VERSION_STRING; }
116
117 // --------------------------------------------------------------------------
118
119 // --------------------------------------------------------------------------
TF_NewSessionOptions()120 TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
TF_DeleteSessionOptions(TF_SessionOptions * opt)121 void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
122
TF_SetTarget(TF_SessionOptions * options,const char * target)123 void TF_SetTarget(TF_SessionOptions* options, const char* target) {
124 options->options.target = target;
125 }
126
TF_SetConfig(TF_SessionOptions * options,const void * proto,size_t proto_len,TF_Status * status)127 void TF_SetConfig(TF_SessionOptions* options, const void* proto,
128 size_t proto_len, TF_Status* status) {
129 if (!options->options.config.ParseFromArray(proto, proto_len)) {
130 status->status = InvalidArgument("Unparseable ConfigProto");
131 }
132 }
133
TF_TensorFromProto(const TF_Buffer * from,TF_Tensor * to,TF_Status * status)134 void TF_TensorFromProto(const TF_Buffer* from, TF_Tensor* to,
135 TF_Status* status) {
136 TF_SetStatus(status, TF_OK, "");
137 tensorflow::TensorProto from_tensor_proto;
138 status->status = BufferToMessage(from, &from_tensor_proto);
139 if (!status->status.ok()) {
140 return;
141 }
142 status->status =
143 tensorflow::down_cast<tensorflow::TensorInterface*>(to->tensor)
144 ->FromProto(from_tensor_proto);
145 }
146 // --------------------------------------------------------------------------
147
TF_NewDeprecatedSession(const TF_SessionOptions * opt,TF_Status * status)148 TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
149 TF_Status* status) {
150 Session* session;
151 status->status = NewSession(opt->options, &session);
152 if (status->status.ok()) {
153 return new TF_DeprecatedSession({session});
154 } else {
155 DCHECK_EQ(nullptr, session);
156 return nullptr;
157 }
158 }
159
TF_CloseDeprecatedSession(TF_DeprecatedSession * s,TF_Status * status)160 void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
161 status->status = s->session->Close();
162 }
163
TF_DeleteDeprecatedSession(TF_DeprecatedSession * s,TF_Status * status)164 void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
165 status->status = ::tensorflow::OkStatus();
166 if (s == nullptr) return;
167 delete s->session;
168 delete s;
169 }
170
TF_ExtendGraph(TF_DeprecatedSession * s,const void * proto,size_t proto_len,TF_Status * status)171 void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
172 size_t proto_len, TF_Status* status) {
173 GraphDef g;
174 if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
175 status->status = InvalidArgument("Invalid GraphDef");
176 return;
177 }
178 status->status = s->session->Extend(g);
179 }
180
181 } // end extern "C"
182
183 // Reset helper for converting character arrays to string vectors.
TF_Reset_Helper(const TF_SessionOptions * opt,const char ** containers,int ncontainers,TF_Status * status)184 static void TF_Reset_Helper(const TF_SessionOptions* opt,
185 const char** containers, int ncontainers,
186 TF_Status* status) {
187 std::vector<string> container_names(ncontainers);
188 for (int i = 0; i < ncontainers; ++i) {
189 container_names[i] = containers[i];
190 }
191
192 status->status = Reset(opt->options, container_names);
193 }
194
195 extern "C" {
196
TF_Reset(const TF_SessionOptions * opt,const char ** containers,int ncontainers,TF_Status * status)197 void TF_Reset(const TF_SessionOptions* opt, const char** containers,
198 int ncontainers, TF_Status* status) {
199 TF_Reset_Helper(opt, containers, ncontainers, status);
200 }
201
202 } // end extern "C"
203
204 namespace tensorflow {
205
RecordMutation(TF_Graph * graph,const TF_Operation & op,const char * mutation_type)206 void RecordMutation(TF_Graph* graph, const TF_Operation& op,
207 const char* mutation_type) {
208 // If any session has already run this node_id, mark this session as
209 // unrunnable.
210 for (auto it : graph->sessions) {
211 mutex_lock session_lock(it.first->mu);
212 if (it.first->last_num_graph_nodes > op.node.id()) {
213 it.second = strings::StrCat(
214 "Operation '", op.node.DebugString(), "' was changed by ",
215 mutation_type,
216 " after it was run by a session. This mutation will have no effect, "
217 "and will trigger an error in the future. Either don't modify "
218 "nodes after running them or create a new session.");
219 }
220 }
221 }
222
223 namespace {
224
225 // Helper method that creates a shape handle for a shape described by dims.
ShapeHandleFromDims(tensorflow::shape_inference::InferenceContext * ic,int num_dims,const int64_t * dims)226 tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
227 tensorflow::shape_inference::InferenceContext* ic, int num_dims,
228 const int64_t* dims) {
229 if (num_dims != -1) {
230 std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
231 dim_vec.reserve(num_dims);
232 for (int i = 0; i < num_dims; ++i) {
233 dim_vec.push_back(ic->MakeDim(dims[i]));
234 }
235 return ic->MakeShape(dim_vec);
236 } else {
237 return ic->UnknownShape();
238 }
239 }
240
241 } // namespace
242
TF_GraphSetOutputHandleShapesAndTypes(TF_Graph * graph,TF_Output output,int num_shapes_and_types,const int64_t ** shapes,const int * ranks,const TF_DataType * types,TF_Status * status)243 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
244 int num_shapes_and_types,
245 const int64_t** shapes,
246 const int* ranks,
247 const TF_DataType* types,
248 TF_Status* status) {
249 Node* node = &output.oper->node;
250
251 mutex_lock l(graph->mu);
252 tensorflow::shape_inference::InferenceContext* ic =
253 graph->refiner.GetContext(node);
254 if (ic == nullptr) {
255 status->status =
256 InvalidArgument("Node ", node->name(), " was not found in the graph");
257 return;
258 }
259
260 auto shape_and_type_vec =
261 std::vector<tensorflow::shape_inference::ShapeAndType>(
262 num_shapes_and_types);
263 for (int i = 0; i < num_shapes_and_types; ++i) {
264 tensorflow::shape_inference::ShapeHandle shape_handle =
265 ShapeHandleFromDims(ic, ranks[i], shapes[i]);
266 shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
267 shape_handle, static_cast<DataType>(types[i]));
268 }
269
270 ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
271 }
272
273 // Helpers for loading a TensorFlow plugin (a .so file).
274 Status LoadDynamicLibrary(const char* library_filename, void** result,
275 const void** buf, size_t* len);
276
277 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
278 // directly, instead of requiring us to serialize to a GraphDef and
279 // call Session::Extend().
ExtendSessionGraphHelper(TF_Session * session,TF_Status * status)280 bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
281 if (session->graph != nullptr) {
282 // Take the graph lock before the session lock to avoid deadlock. This is
283 // safe since session->graph does not change.
284 session->graph->mu.lock();
285 mutex_lock session_lock(session->mu);
286 const Graph& graph = session->graph->graph;
287
288 const string& mutation_warning = session->graph->sessions[session];
289 if (!mutation_warning.empty()) {
290 // TODO(b/74949947): turn this back into an error status
291 LOG(WARNING) << mutation_warning;
292 session->graph->sessions[session].clear();
293 }
294
295 const auto num_nodes = graph.num_node_ids();
296 if (session->last_num_graph_nodes < num_nodes) {
297 // TODO(nolivia): check this on a subset of the graph instead of all of
298 // it.
299 status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
300 if (!status->status.ok()) {
301 session->graph->mu.unlock();
302 return false;
303 }
304
305 GraphDef graph_def;
306 *graph_def.mutable_versions() = graph.versions();
307 // Fill graph_def with nodes with ids in the range
308 // [session->last_num_graph_nodes, num_nodes), that is the nodes
309 // added since the last TF_SessionRun() call.
310 for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
311 Node* const node = graph.FindNodeId(id);
312 if (node != nullptr && node->IsOp()) {
313 NodeDef* const node_def = graph_def.add_node();
314 *node_def = node->def();
315 }
316 }
317 *graph_def.mutable_library() = graph.flib_def().ToProto();
318 session->graph->mu.unlock();
319 status->status = session->session->Extend(std::move(graph_def));
320 if (!status->status.ok()) {
321 // Contract is we always delete input_values[i].
322 return false;
323 }
324 // Note: session->session is not modified if Extend() fails, so
325 // we only set last_num_graph_nodes if it succeeds.
326 session->last_num_graph_nodes = num_nodes;
327 } else {
328 session->graph->mu.unlock();
329 }
330 }
331 return true;
332 }
333
334 } // namespace tensorflow
335
TF_Run_Setup(int noutputs,TF_Tensor ** c_outputs,TF_Status * status)336 static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
337 TF_Status* status) {
338 status->status = ::tensorflow::OkStatus();
339 for (int i = 0; i < noutputs; ++i) {
340 c_outputs[i] = nullptr;
341 }
342 }
343
344 // TF_TensorToTensorV1 decodes a string serialization to DT_RESOURCE.
345 // In the TFv1 convention, TF_Tensor can hold a string serialization of
346 // DT_RESOURCE. The string serialization is converted back to a
347 // ResourceHandle during Session run where the TF_Tensor is converted to a
348 // Tensor.
349 // TFv2 does not depend on this conversion. There is no matching
350 // TF_TensorFromTensorV1 because the conversion to string is performed by the
351 // python side of Session.
TF_TensorToTensorV1(const TF_Tensor * src,Tensor * dst)352 static Status TF_TensorToTensorV1(const TF_Tensor* src, Tensor* dst) {
353 Status status = TF_TensorToTensor(src, dst);
354 if (!status.ok()) {
355 return status;
356 }
357 if (dst->dtype() == tensorflow::DT_RESOURCE) {
358 const auto tensor_interface =
359 tensorflow::down_cast<const tensorflow::TensorInterface*>(src->tensor);
360
361 if (dst->dims() != 0) {
362 return InvalidArgument(
363 "Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with "
364 "shape ",
365 dst->shape().DebugString());
366 }
367 *dst = tensorflow::Tensor(tensorflow::DT_RESOURCE, dst->shape());
368 if (!dst->scalar<tensorflow::ResourceHandle>()().ParseFromString(
369 string(static_cast<const char*>(tensor_interface->Data()),
370 tensor_interface->ByteSize()))) {
371 return InvalidArgument(
372 "Malformed TF_RESOURCE tensor: unable to parse resource handle");
373 }
374 return ::tensorflow::OkStatus();
375 }
376 return ::tensorflow::OkStatus();
377 }
378
TF_Run_Inputs(TF_Tensor * const * c_inputs,std::vector<std::pair<string,Tensor>> * input_pairs,TF_Status * status)379 static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
380 std::vector<std::pair<string, Tensor>>* input_pairs,
381 TF_Status* status) {
382 const int ninputs = input_pairs->size();
383 for (int i = 0; i < ninputs; ++i) {
384 status->status =
385 TF_TensorToTensorV1(c_inputs[i], &(*input_pairs)[i].second);
386 if (!status->status.ok()) return false;
387 }
388 return true;
389 }
390
391 // Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
392 // result in a zero-sized tensor.
EmptyTensor(TF_DataType dtype,const tensorflow::TensorShape & shape)393 static TF_Tensor* EmptyTensor(TF_DataType dtype,
394 const tensorflow::TensorShape& shape) {
395 static char empty;
396 int64_t nelems = 1;
397 std::vector<int64_t> dims;
398 dims.reserve(shape.dims());
399 for (int i = 0; i < shape.dims(); ++i) {
400 dims.push_back(shape.dim_size(i));
401 nelems *= shape.dim_size(i);
402 }
403 CHECK_EQ(nelems, 0);
404 return TF_NewTensor(
405 dtype, reinterpret_cast<const int64_t*>(dims.data()), shape.dims(),
406 reinterpret_cast<void*>(&empty), 0, [](void*, size_t, void*) {}, nullptr);
407 }
408
TF_Run_Helper(Session * session,const char * handle,const TF_Buffer * run_options,const std::vector<std::pair<string,Tensor>> & input_pairs,const std::vector<string> & output_tensor_names,TF_Tensor ** c_outputs,const std::vector<string> & target_oper_names,TF_Buffer * run_metadata,TF_Status * status)409 static void TF_Run_Helper(
410 Session* session, const char* handle, const TF_Buffer* run_options,
411 // Input tensors
412 const std::vector<std::pair<string, Tensor>>& input_pairs,
413 // Output tensors
414 const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
415 // Target nodes
416 const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
417 TF_Status* status) {
418 const int noutputs = output_tensor_names.size();
419 std::vector<Tensor> outputs(noutputs);
420 Status result;
421
422 if (handle == nullptr) {
423 RunOptions run_options_proto;
424 if (run_options != nullptr && !run_options_proto.ParseFromArray(
425 run_options->data, run_options->length)) {
426 status->status = InvalidArgument("Unparseable RunOptions proto");
427 return;
428 }
429 if (run_metadata != nullptr && run_metadata->data != nullptr) {
430 status->status =
431 InvalidArgument("Passing non-empty run_metadata is invalid.");
432 return;
433 }
434
435 RunMetadata run_metadata_proto;
436 result = session->Run(run_options_proto, input_pairs, output_tensor_names,
437 target_oper_names, &outputs, &run_metadata_proto);
438
439 // Serialize back to upstream client, who now owns the new buffer
440 if (run_metadata != nullptr) {
441 status->status = MessageToBuffer(run_metadata_proto, run_metadata);
442 if (!status->status.ok()) return;
443 }
444 } else {
445 // NOTE(zongheng): PRun does not support RunOptions yet.
446 result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
447 }
448 if (!result.ok()) {
449 status->status = result;
450 return;
451 }
452
453 // Store results in c_outputs[]
454 for (int i = 0; i < noutputs; ++i) {
455 const Tensor& src = outputs[i];
456 if (!src.IsInitialized() || src.NumElements() == 0) {
457 c_outputs[i] =
458 EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
459 continue;
460 }
461 c_outputs[i] = TF_TensorFromTensor(src, &status->status);
462 if (!status->status.ok()) return;
463 }
464 }
465
466 extern "C" {
467
TF_Run(TF_DeprecatedSession * s,const TF_Buffer * run_options,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Buffer * run_metadata,TF_Status * status)468 void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
469 // Input tensors
470 const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
471 // Output tensors
472 const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
473 // Target nodes
474 const char** c_target_oper_names, int ntargets,
475 TF_Buffer* run_metadata, TF_Status* status) {
476 TF_Run_Setup(noutputs, c_outputs, status);
477 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
478 if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
479 for (int i = 0; i < ninputs; ++i) {
480 input_pairs[i].first = c_input_names[i];
481 }
482 std::vector<string> output_names(noutputs);
483 for (int i = 0; i < noutputs; ++i) {
484 output_names[i] = c_output_names[i];
485 }
486 std::vector<string> target_oper_names(ntargets);
487 for (int i = 0; i < ntargets; ++i) {
488 target_oper_names[i] = c_target_oper_names[i];
489 }
490 TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
491 c_outputs, target_oper_names, run_metadata, status);
492 }
493
TF_PRunSetup(TF_DeprecatedSession * s,const char ** c_input_names,int ninputs,const char ** c_output_names,int noutputs,const char ** c_target_oper_names,int ntargets,const char ** handle,TF_Status * status)494 void TF_PRunSetup(TF_DeprecatedSession* s,
495 // Input names
496 const char** c_input_names, int ninputs,
497 // Output names
498 const char** c_output_names, int noutputs,
499 // Target nodes
500 const char** c_target_oper_names, int ntargets,
501 const char** handle, TF_Status* status) {
502 *handle = nullptr;
503
504 std::vector<string> input_names(ninputs);
505 std::vector<string> output_names(noutputs);
506 std::vector<string> target_oper_names(ntargets);
507 for (int i = 0; i < ninputs; ++i) {
508 input_names[i] = c_input_names[i];
509 }
510 for (int i = 0; i < noutputs; ++i) {
511 output_names[i] = c_output_names[i];
512 }
513 for (int i = 0; i < ntargets; ++i) {
514 target_oper_names[i] = c_target_oper_names[i];
515 }
516 string new_handle;
517 status->status = s->session->PRunSetup(input_names, output_names,
518 target_oper_names, &new_handle);
519 if (status->status.ok()) {
520 char* buf = new char[new_handle.size() + 1];
521 memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
522 *handle = buf;
523 }
524 }
525
TF_PRun(TF_DeprecatedSession * s,const char * handle,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Status * status)526 void TF_PRun(TF_DeprecatedSession* s, const char* handle,
527 // Input tensors
528 const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
529 // Output tensors
530 const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
531 // Target nodes
532 const char** c_target_oper_names, int ntargets,
533 TF_Status* status) {
534 TF_Run_Setup(noutputs, c_outputs, status);
535 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
536 if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
537 for (int i = 0; i < ninputs; ++i) {
538 input_pairs[i].first = c_input_names[i];
539 }
540
541 std::vector<string> output_names(noutputs);
542 for (int i = 0; i < noutputs; ++i) {
543 output_names[i] = c_output_names[i];
544 }
545 std::vector<string> target_oper_names(ntargets);
546 for (int i = 0; i < ntargets; ++i) {
547 target_oper_names[i] = c_target_oper_names[i];
548 }
549 TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
550 c_outputs, target_oper_names, nullptr, status);
551 }
552
TF_LoadLibrary(const char * library_filename,TF_Status * status)553 TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
554 TF_Library* lib_handle = new TF_Library;
555 status->status = tensorflow::LoadDynamicLibrary(
556 library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
557 &lib_handle->op_list.length);
558 if (!status->status.ok()) {
559 delete lib_handle;
560 return nullptr;
561 }
562 return lib_handle;
563 }
564
TF_GetOpList(TF_Library * lib_handle)565 TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
566
TF_DeleteLibraryHandle(TF_Library * lib_handle)567 void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
568 if (lib_handle == nullptr) return;
569 tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
570 delete lib_handle;
571 }
572
TF_GetAllOpList()573 TF_Buffer* TF_GetAllOpList() {
574 std::vector<tensorflow::OpDef> op_defs;
575 tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
576 tensorflow::OpList op_list;
577 for (const auto& op : op_defs) {
578 *(op_list.add_op()) = op;
579 }
580 TF_Buffer* ret = TF_NewBuffer();
581 TF_CHECK_OK(MessageToBuffer(op_list, ret));
582 return ret;
583 }
584
585 // --------------------------------------------------------------------------
586 // ListDevices & SessionListDevices API
587
TF_DeleteDeviceList(TF_DeviceList * list)588 void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; }
589
TF_SessionListDevices(TF_Session * session,TF_Status * status)590 TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
591 TF_DeviceList* response = new TF_DeviceList;
592 if (session && session->session)
593 status->status = session->session->ListDevices(&response->response);
594 return response;
595 }
596
TF_DeprecatedSessionListDevices(TF_DeprecatedSession * session,TF_Status * status)597 TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
598 TF_Status* status) {
599 TF_DeviceList* response = new TF_DeviceList;
600 if (session && session->session)
601 status->status = session->session->ListDevices(&response->response);
602 return response;
603 }
604
TF_DeviceListCount(const TF_DeviceList * list)605 int TF_DeviceListCount(const TF_DeviceList* list) {
606 return list->response.size();
607 }
608
609 #define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
610 return_type method_name(const TF_DeviceList* list, const int index, \
611 TF_Status* status) { \
612 if (list == nullptr) { \
613 status->status = InvalidArgument("list is null!"); \
614 return err_val; \
615 } \
616 if (index < 0 || index >= list->response.size()) { \
617 status->status = InvalidArgument("index out of bounds"); \
618 return err_val; \
619 } \
620 status->status = ::tensorflow::OkStatus(); \
621 return list->response[index].accessor; \
622 }
623
624 TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
625 TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
626 nullptr);
627 TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
628 TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);
629
630 #undef TF_DEVICELIST_METHOD
631
632 } // end extern "C"
633
634 // --------------------------------------------------------------------------
635 // New Graph and Session API
636
637 // Helper functions -----------------------------------------------------------
638
639 namespace {
640
ToOperation(Node * node)641 TF_Operation* ToOperation(Node* node) {
642 return static_cast<TF_Operation*>(static_cast<void*>(node));
643 }
644
OutputName(const TF_Output & output)645 string OutputName(const TF_Output& output) {
646 return StrCat(output.oper->node.name(), ":", output.index);
647 }
648
GetAttrValue(TF_Operation * oper,const char * attr_name,TF_Status * status)649 const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
650 const char* attr_name,
651 TF_Status* status) {
652 const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
653 if (attr == nullptr) {
654 status->status = InvalidArgument("Operation '", oper->node.name(),
655 "' has no attr named '", attr_name, "'.");
656 }
657 return attr;
658 }
659
ToTensorId(const TF_Output & output)660 TensorId ToTensorId(const TF_Output& output) {
661 return TensorId(output.oper->node.name(), output.index);
662 }
663
664 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
OutputsFromTFOutputs(TF_Output * tf_outputs,int n)665 std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
666 int n) {
667 std::vector<tensorflow::Output> outputs(n);
668 for (int i = 0; i < n; ++i) {
669 outputs[i] =
670 tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
671 }
672 return outputs;
673 }
674
TFOutputsFromOutputs(const std::vector<tensorflow::Output> & outputs,TF_Output * tf_outputs)675 void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
676 TF_Output* tf_outputs) {
677 for (int i = 0; i < outputs.size(); i++) {
678 tf_outputs[i].oper = ToOperation(outputs[i].node());
679 tf_outputs[i].index = outputs[i].index();
680 }
681 }
682 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
683
684 } // namespace
685
686 // Shape functions -----------------------------------------------------------
687
TF_GraphSetTensorShape(TF_Graph * graph,TF_Output output,const int64_t * dims,const int num_dims,TF_Status * status)688 void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
689 const int64_t* dims, const int num_dims,
690 TF_Status* status) {
691 Node* node = &output.oper->node;
692
693 mutex_lock l(graph->mu);
694 tensorflow::shape_inference::InferenceContext* ic =
695 graph->refiner.GetContext(node);
696 if (ic == nullptr) {
697 status->status =
698 InvalidArgument("Node ", node->name(), " was not found in the graph");
699 return;
700 }
701 tensorflow::shape_inference::ShapeHandle new_shape =
702 tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
703 status->status = graph->refiner.SetShape(node, output.index, new_shape);
704 }
705
TF_GraphGetTensorNumDims(TF_Graph * graph,TF_Output output,TF_Status * status)706 int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
707 TF_Status* status) {
708 Node* node = &output.oper->node;
709
710 mutex_lock l(graph->mu);
711 tensorflow::shape_inference::InferenceContext* ic =
712 graph->refiner.GetContext(node);
713 if (ic == nullptr) {
714 status->status =
715 InvalidArgument("Node ", node->name(), " was not found in the graph");
716 return -1;
717 }
718
719 tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
720
721 // Unknown rank means the number of dimensions is -1.
722 if (!ic->RankKnown(shape)) {
723 return -1;
724 }
725
726 return ic->Rank(shape);
727 }
728
TF_GraphGetTensorShape(TF_Graph * graph,TF_Output output,int64_t * dims,int num_dims,TF_Status * status)729 void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
730 int num_dims, TF_Status* status) {
731 Node* node = &output.oper->node;
732
733 mutex_lock l(graph->mu);
734 tensorflow::shape_inference::InferenceContext* ic =
735 graph->refiner.GetContext(node);
736 if (ic == nullptr) {
737 status->status =
738 InvalidArgument("Node ", node->name(), " was not found in the graph");
739 return;
740 }
741
742 tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
743
744 int rank = -1;
745 if (ic->RankKnown(shape)) {
746 rank = ic->Rank(shape);
747 }
748
749 if (num_dims != rank) {
750 status->status = InvalidArgument("Expected rank is ", num_dims,
751 " but actual rank is ", rank);
752 return;
753 }
754
755 if (num_dims == 0) {
756 // Output shape is a scalar.
757 return;
758 }
759
760 // Rank is greater than 0, so fill in the values, if known, and
761 // -1 for unknown values.
762 for (int i = 0; i < num_dims; ++i) {
763 auto dim = ic->Dim(shape, i);
764 int64_t value = -1;
765 if (ic->ValueKnown(dim)) {
766 value = ic->Value(dim);
767 }
768 dims[i] = value;
769 }
770 }
771
772 // TF_OperationDescription functions ------------------------------------------
773
774 extern "C" {
775
TF_NewOperationLocked(TF_Graph * graph,const char * op_type,const char * oper_name)776 TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
777 const char* op_type,
778 const char* oper_name)
779 TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
780 return new TF_OperationDescription(graph, op_type, oper_name);
781 }
782
TF_NewOperation(TF_Graph * graph,const char * op_type,const char * oper_name)783 TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
784 const char* oper_name) {
785 mutex_lock l(graph->mu);
786 return TF_NewOperationLocked(graph, op_type, oper_name);
787 }
788
TF_SetDevice(TF_OperationDescription * desc,const char * device)789 void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
790 desc->node_builder.Device(device);
791 }
792
TF_AddInput(TF_OperationDescription * desc,TF_Output input)793 void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
794 desc->node_builder.Input(&input.oper->node, input.index);
795 }
796
TF_AddInputList(TF_OperationDescription * desc,const TF_Output * inputs,int num_inputs)797 void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
798 int num_inputs) {
799 std::vector<NodeBuilder::NodeOut> input_list;
800 input_list.reserve(num_inputs);
801 for (int i = 0; i < num_inputs; ++i) {
802 input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
803 }
804 desc->node_builder.Input(input_list);
805 }
806
TF_AddControlInput(TF_OperationDescription * desc,TF_Operation * input)807 void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
808 desc->node_builder.ControlInput(&input->node);
809 }
810
TF_ColocateWith(TF_OperationDescription * desc,TF_Operation * op)811 void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
812 desc->colocation_constraints.emplace(
813 StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
814 }
815
TF_SetAttrString(TF_OperationDescription * desc,const char * attr_name,const void * value,size_t length)816 void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
817 const void* value, size_t length) {
818 tensorflow::StringPiece s(static_cast<const char*>(value), length);
819 desc->node_builder.Attr(attr_name, s);
820 }
821
TF_SetAttrStringList(TF_OperationDescription * desc,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)822 void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
823 const void* const* values, const size_t* lengths,
824 int num_values) {
825 if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
826 desc->colocation_constraints.clear();
827 for (int i = 0; i < num_values; ++i) {
828 desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
829 lengths[i]);
830 }
831 } else {
832 std::vector<tensorflow::StringPiece> v;
833 v.reserve(num_values);
834 for (int i = 0; i < num_values; ++i) {
835 v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
836 }
837 desc->node_builder.Attr(attr_name, v);
838 }
839 }
840
TF_SetAttrInt(TF_OperationDescription * desc,const char * attr_name,int64_t value)841 void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
842 int64_t value) {
843 desc->node_builder.Attr(attr_name, static_cast<int64_t>(value));
844 }
845
TF_SetAttrIntList(TF_OperationDescription * desc,const char * attr_name,const int64_t * values,int num_values)846 void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
847 const int64_t* values, int num_values) {
848 desc->node_builder.Attr(
849 attr_name, ArraySlice<const int64_t>(
850 reinterpret_cast<const int64_t*>(values), num_values));
851 }
852
TF_SetAttrFloat(TF_OperationDescription * desc,const char * attr_name,float value)853 void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
854 float value) {
855 desc->node_builder.Attr(attr_name, value);
856 }
857
TF_SetAttrFloatList(TF_OperationDescription * desc,const char * attr_name,const float * values,int num_values)858 void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
859 const float* values, int num_values) {
860 desc->node_builder.Attr(attr_name,
861 ArraySlice<const float>(values, num_values));
862 }
863
TF_SetAttrBool(TF_OperationDescription * desc,const char * attr_name,unsigned char value)864 void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
865 unsigned char value) {
866 desc->node_builder.Attr(attr_name, static_cast<bool>(value));
867 }
868
TF_SetAttrBoolList(TF_OperationDescription * desc,const char * attr_name,const unsigned char * values,int num_values)869 void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
870 const unsigned char* values, int num_values) {
871 std::unique_ptr<bool[]> b(new bool[num_values]);
872 for (int i = 0; i < num_values; ++i) {
873 b[i] = values[i];
874 }
875 desc->node_builder.Attr(attr_name,
876 ArraySlice<const bool>(b.get(), num_values));
877 }
878
TF_SetAttrType(TF_OperationDescription * desc,const char * attr_name,TF_DataType value)879 void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
880 TF_DataType value) {
881 desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
882 }
883
TF_SetAttrTypeList(TF_OperationDescription * desc,const char * attr_name,const TF_DataType * values,int num_values)884 void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
885 const TF_DataType* values, int num_values) {
886 desc->node_builder.Attr(
887 attr_name, ArraySlice<const DataType>(
888 reinterpret_cast<const DataType*>(values), num_values));
889 }
890
TF_SetAttrPlaceholder(TF_OperationDescription * desc,const char * attr_name,const char * placeholder)891 void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name,
892 const char* placeholder) {
893 tensorflow::AttrValue attr_value;
894 attr_value.set_placeholder(placeholder);
895 desc->node_builder.Attr(attr_name, attr_value);
896 }
897
TF_SetAttrFuncName(TF_OperationDescription * desc,const char * attr_name,const char * value,size_t length)898 void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
899 const char* value, size_t length) {
900 tensorflow::NameAttrList func_name;
901 func_name.set_name(string(value, value + length));
902 desc->node_builder.Attr(attr_name, func_name);
903 }
904
TF_SetAttrShape(TF_OperationDescription * desc,const char * attr_name,const int64_t * dims,int num_dims)905 void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
906 const int64_t* dims, int num_dims) {
907 PartialTensorShape shape;
908 if (num_dims >= 0) {
909 shape = PartialTensorShape(
910 ArraySlice<int64_t>(reinterpret_cast<const int64_t*>(dims), num_dims));
911 }
912 desc->node_builder.Attr(attr_name, shape);
913 }
914
TF_SetAttrShapeList(TF_OperationDescription * desc,const char * attr_name,const int64_t * const * dims,const int * num_dims,int num_shapes)915 void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
916 const int64_t* const* dims, const int* num_dims,
917 int num_shapes) {
918 std::vector<PartialTensorShape> shapes;
919 shapes.reserve(num_shapes);
920 for (int i = 0; i < num_shapes; ++i) {
921 if (num_dims[i] < 0) {
922 shapes.emplace_back();
923 } else {
924 shapes.emplace_back(ArraySlice<int64_t>(
925 reinterpret_cast<const int64_t*>(dims[i]), num_dims[i]));
926 }
927 }
928 desc->node_builder.Attr(attr_name, shapes);
929 }
930
TF_SetAttrTensorShapeProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)931 void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
932 const char* attr_name, const void* proto,
933 size_t proto_len, TF_Status* status) {
934 // shape.ParseFromArray takes an int as length, this function takes size_t,
935 // make sure there is no information loss.
936 if (proto_len > std::numeric_limits<int>::max()) {
937 status->status = InvalidArgument(
938 "proto_len (", proto_len,
939 " bytes) is too large to be parsed by the protocol buffer library");
940 return;
941 }
942 TensorShapeProto shape;
943 if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
944 desc->node_builder.Attr(attr_name, shape);
945 status->status = ::tensorflow::OkStatus();
946 } else {
947 status->status = InvalidArgument("Unparseable TensorShapeProto");
948 }
949 }
950
TF_SetAttrTensorShapeProtoList(TF_OperationDescription * desc,const char * attr_name,const void * const * protos,const size_t * proto_lens,int num_shapes,TF_Status * status)951 void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
952 const char* attr_name,
953 const void* const* protos,
954 const size_t* proto_lens, int num_shapes,
955 TF_Status* status) {
956 std::vector<TensorShapeProto> shapes;
957 shapes.resize(num_shapes);
958 for (int i = 0; i < num_shapes; ++i) {
959 if (proto_lens[i] > std::numeric_limits<int>::max()) {
960 status->status = InvalidArgument(
961 "length of element ", i, " in the list (", proto_lens[i],
962 " bytes) is too large to be parsed by the protocol buffer library");
963 return;
964 }
965 if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) {
966 status->status =
967 InvalidArgument("Unparseable TensorShapeProto at index ", i);
968 return;
969 }
970 }
971 desc->node_builder.Attr(attr_name, shapes);
972 status->status = ::tensorflow::OkStatus();
973 }
974
TF_SetAttrTensor(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * value,TF_Status * status)975 void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
976 TF_Tensor* value, TF_Status* status) {
977 Tensor t;
978 status->status = TF_TensorToTensor(value, &t);
979 if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
980 }
981
TF_SetAttrTensorList(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * const * values,int num_values,TF_Status * status)982 void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
983 TF_Tensor* const* values, int num_values,
984 TF_Status* status) {
985 status->status = ::tensorflow::OkStatus();
986 std::vector<Tensor> t;
987 t.reserve(num_values);
988
989 for (int i = 0; i < num_values && status->status.ok(); ++i) {
990 Tensor v;
991 status->status = TF_TensorToTensor(values[i], &v);
992 t.emplace_back(v);
993 }
994
995 if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
996 }
997
TF_SetAttrValueProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)998 void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
999 const void* proto, size_t proto_len,
1000 TF_Status* status) {
1001 tensorflow::AttrValue attr_value;
1002 if (!attr_value.ParseFromArray(proto, proto_len)) {
1003 status->status = InvalidArgument("Unparseable AttrValue proto");
1004 return;
1005 }
1006
1007 if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1008 if (attr_value.value_case() != tensorflow::AttrValue::kList &&
1009 attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
1010 status->status =
1011 InvalidArgument("Expected \"list\" field for \"",
1012 tensorflow::kColocationAttrName, "\" attribute");
1013 return;
1014 }
1015 desc->colocation_constraints.clear();
1016 for (const string& location : attr_value.list().s()) {
1017 desc->colocation_constraints.insert(location);
1018 }
1019 } else {
1020 desc->node_builder.Attr(attr_name, std::move(attr_value));
1021 }
1022
1023 status->status = ::tensorflow::OkStatus();
1024 }
1025
TF_FinishOperationLocked(TF_OperationDescription * desc,TF_Status * status)1026 TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
1027 TF_Status* status)
1028 TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
1029 Node* ret = nullptr;
1030
1031 if (desc->graph->name_map.count(desc->node_builder.node_name())) {
1032 status->status = InvalidArgument("Duplicate node name in graph: '",
1033 desc->node_builder.node_name(), "'");
1034 } else {
1035 if (!desc->colocation_constraints.empty()) {
1036 desc->node_builder.Attr(
1037 tensorflow::kColocationAttrName,
1038 std::vector<string>(desc->colocation_constraints.begin(),
1039 desc->colocation_constraints.end()));
1040 }
1041 status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret,
1042 /*consume=*/true);
1043
1044 if (status->status.ok()) {
1045 // Run shape inference function for newly added node.
1046 status->status = desc->graph->refiner.AddNode(ret);
1047 }
1048 if (status->status.ok()) {
1049 // Add the node to the name-to-node mapping.
1050 desc->graph->name_map[ret->name()] = ret;
1051 } else if (ret != nullptr) {
1052 desc->graph->graph.RemoveNode(ret);
1053 ret = nullptr;
1054 }
1055 }
1056
1057 delete desc;
1058
1059 return ToOperation(ret);
1060 }
1061
TF_FinishOperation(TF_OperationDescription * desc,TF_Status * status)1062 TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
1063 TF_Status* status) {
1064 mutex_lock l(desc->graph->mu);
1065 return TF_FinishOperationLocked(desc, status);
1066 }
1067
// TF_Operation functions
// ----------------------------------------------------------
1070
TF_OperationName(TF_Operation * oper)1071 const char* TF_OperationName(TF_Operation* oper) {
1072 return oper->node.name().c_str();
1073 }
1074
TF_OperationOpType(TF_Operation * oper)1075 const char* TF_OperationOpType(TF_Operation* oper) {
1076 return oper->node.type_string().c_str();
1077 }
1078
TF_OperationDevice(TF_Operation * oper)1079 const char* TF_OperationDevice(TF_Operation* oper) {
1080 return oper->node.requested_device().c_str();
1081 }
1082
TF_OperationNumOutputs(TF_Operation * oper)1083 int TF_OperationNumOutputs(TF_Operation* oper) {
1084 return oper->node.num_outputs();
1085 }
1086
TF_OperationOutputType(TF_Output oper_out)1087 TF_DataType TF_OperationOutputType(TF_Output oper_out) {
1088 return static_cast<TF_DataType>(
1089 oper_out.oper->node.output_type(oper_out.index));
1090 }
1091
TF_OperationOutputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1092 int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
1093 TF_Status* status) {
1094 NameRangeMap name_ranges;
1095 status->status =
1096 NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
1097 if (!status->status.ok()) return -1;
1098 auto iter = name_ranges.find(arg_name);
1099 if (iter == name_ranges.end()) {
1100 status->status = InvalidArgument("Output arg '", arg_name, "' not found");
1101 return -1;
1102 }
1103 return iter->second.second - iter->second.first;
1104 }
1105
TF_OperationNumInputs(TF_Operation * oper)1106 int TF_OperationNumInputs(TF_Operation* oper) {
1107 return oper->node.num_inputs();
1108 }
1109
TF_OperationInputType(TF_Input oper_in)1110 TF_DataType TF_OperationInputType(TF_Input oper_in) {
1111 return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
1112 }
1113
TF_OperationInputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1114 int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
1115 TF_Status* status) {
1116 NameRangeMap name_ranges;
1117 status->status =
1118 NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
1119 if (!status->status.ok()) return -1;
1120 auto iter = name_ranges.find(arg_name);
1121 if (iter == name_ranges.end()) {
1122 status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1123 return -1;
1124 }
1125 return iter->second.second - iter->second.first;
1126 }
1127
TF_OperationInput(TF_Input oper_in)1128 TF_Output TF_OperationInput(TF_Input oper_in) {
1129 const tensorflow::Edge* edge;
1130 Status s = oper_in.oper->node.input_edge(oper_in.index, &edge);
1131 if (!s.ok()) {
1132 return {nullptr, -1};
1133 }
1134
1135 return {ToOperation(edge->src()), edge->src_output()};
1136 }
1137
TF_OperationAllInputs(TF_Operation * oper,TF_Output * inputs,int max_inputs)1138 void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs,
1139 int max_inputs) {
1140 for (auto* edge : oper->node.in_edges()) {
1141 if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) {
1142 inputs[edge->dst_input()] = {ToOperation(edge->src()),
1143 edge->src_output()};
1144 }
1145 }
1146 }
1147
TF_OperationOutputNumConsumers(TF_Output oper_out)1148 int TF_OperationOutputNumConsumers(TF_Output oper_out) {
1149 int count = 0;
1150 for (const auto* edge : oper_out.oper->node.out_edges()) {
1151 if (edge->src_output() == oper_out.index) {
1152 ++count;
1153 }
1154 }
1155 return count;
1156 }
1157
TF_OperationOutputConsumers(TF_Output oper_out,TF_Input * consumers,int max_consumers)1158 int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
1159 int max_consumers) {
1160 int count = 0;
1161 for (const auto* edge : oper_out.oper->node.out_edges()) {
1162 if (edge->src_output() == oper_out.index) {
1163 if (count < max_consumers) {
1164 consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
1165 }
1166 ++count;
1167 }
1168 }
1169 return count;
1170 }
1171
TF_OperationNumControlInputs(TF_Operation * oper)1172 int TF_OperationNumControlInputs(TF_Operation* oper) {
1173 int count = 0;
1174 for (const auto* edge : oper->node.in_edges()) {
1175 if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1176 ++count;
1177 }
1178 }
1179 return count;
1180 }
1181
TF_OperationGetControlInputs(TF_Operation * oper,TF_Operation ** control_inputs,int max_control_inputs)1182 int TF_OperationGetControlInputs(TF_Operation* oper,
1183 TF_Operation** control_inputs,
1184 int max_control_inputs) {
1185 int count = 0;
1186 for (const auto* edge : oper->node.in_edges()) {
1187 if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1188 if (count < max_control_inputs) {
1189 control_inputs[count] = ToOperation(edge->src());
1190 }
1191 ++count;
1192 }
1193 }
1194 return count;
1195 }
1196
TF_OperationNumControlOutputs(TF_Operation * oper)1197 int TF_OperationNumControlOutputs(TF_Operation* oper) {
1198 int count = 0;
1199 for (const auto* edge : oper->node.out_edges()) {
1200 if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1201 ++count;
1202 }
1203 }
1204 return count;
1205 }
1206
TF_OperationGetControlOutputs(TF_Operation * oper,TF_Operation ** control_outputs,int max_control_outputs)1207 int TF_OperationGetControlOutputs(TF_Operation* oper,
1208 TF_Operation** control_outputs,
1209 int max_control_outputs) {
1210 int count = 0;
1211 for (const auto* edge : oper->node.out_edges()) {
1212 if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1213 if (count < max_control_outputs) {
1214 control_outputs[count] = ToOperation(edge->dst());
1215 }
1216 ++count;
1217 }
1218 }
1219 return count;
1220 }
1221
TF_OperationGetAttrMetadata(TF_Operation * oper,const char * attr_name,TF_Status * status)1222 TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1223 const char* attr_name,
1224 TF_Status* status) {
1225 TF_AttrMetadata metadata;
1226 const auto* attr = GetAttrValue(oper, attr_name, status);
1227 if (!status->status.ok()) return metadata;
1228 switch (attr->value_case()) {
1229 #define SINGLE_CASE(kK, attr_type, size_expr) \
1230 case tensorflow::AttrValue::kK: \
1231 metadata.is_list = 0; \
1232 metadata.list_size = -1; \
1233 metadata.type = attr_type; \
1234 metadata.total_size = size_expr; \
1235 break;
1236
1237 SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
1238 SINGLE_CASE(kI, TF_ATTR_INT, -1);
1239 SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
1240 SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
1241 SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
1242 SINGLE_CASE(kShape, TF_ATTR_SHAPE,
1243 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
1244 SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
1245 #undef SINGLE_CASE
1246
1247 case tensorflow::AttrValue::kList:
1248 metadata.is_list = 1;
1249 metadata.list_size = 0;
1250 metadata.total_size = -1;
1251 #define LIST_CASE(field, attr_type, ...) \
1252 if (attr->list().field##_size() > 0) { \
1253 metadata.type = attr_type; \
1254 metadata.list_size = attr->list().field##_size(); \
1255 __VA_ARGS__; \
1256 break; \
1257 }
1258
1259 LIST_CASE(
1260 s, TF_ATTR_STRING, metadata.total_size = 0;
1261 for (int i = 0; i < attr->list().s_size();
1262 ++i) { metadata.total_size += attr->list().s(i).size(); });
1263 LIST_CASE(i, TF_ATTR_INT);
1264 LIST_CASE(f, TF_ATTR_FLOAT);
1265 LIST_CASE(b, TF_ATTR_BOOL);
1266 LIST_CASE(type, TF_ATTR_TYPE);
1267 LIST_CASE(
1268 shape, TF_ATTR_SHAPE, metadata.total_size = 0;
1269 for (int i = 0; i < attr->list().shape_size(); ++i) {
1270 const auto& s = attr->list().shape(i);
1271 metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
1272 });
1273 LIST_CASE(tensor, TF_ATTR_TENSOR);
1274 LIST_CASE(tensor, TF_ATTR_FUNC);
1275 #undef LIST_CASE
1276 // All lists empty, determine the type from the OpDef.
1277 if (metadata.list_size == 0) {
1278 for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
1279 const auto& a = oper->node.op_def().attr(i);
1280 if (a.name() != attr_name) continue;
1281 const string& typestr = a.type();
1282 if (typestr == "list(string)") {
1283 metadata.type = TF_ATTR_STRING;
1284 } else if (typestr == "list(int)") {
1285 metadata.type = TF_ATTR_INT;
1286 } else if (typestr == "list(float)") {
1287 metadata.type = TF_ATTR_FLOAT;
1288 } else if (typestr == "list(bool)") {
1289 metadata.type = TF_ATTR_BOOL;
1290 } else if (typestr == "list(type)") {
1291 metadata.type = TF_ATTR_TYPE;
1292 } else if (typestr == "list(shape)") {
1293 metadata.type = TF_ATTR_SHAPE;
1294 } else if (typestr == "list(tensor)") {
1295 metadata.type = TF_ATTR_TENSOR;
1296 } else if (typestr == "list(func)") {
1297 metadata.type = TF_ATTR_FUNC;
1298 } else {
1299 status->status = InvalidArgument(
1300 "Attribute '", attr_name,
1301 "' has an empty value of an unrecognized type '", typestr, "'");
1302 return metadata;
1303 }
1304 }
1305 }
1306 break;
1307
1308 case tensorflow::AttrValue::kPlaceholder:
1309 metadata.is_list = 0;
1310 metadata.list_size = -1;
1311 metadata.type = TF_ATTR_PLACEHOLDER;
1312 metadata.total_size = -1;
1313 break;
1314
1315 case tensorflow::AttrValue::kFunc:
1316 metadata.is_list = 0;
1317 metadata.list_size = -1;
1318 metadata.type = TF_ATTR_FUNC;
1319 metadata.total_size = -1;
1320 break;
1321
1322 case tensorflow::AttrValue::VALUE_NOT_SET:
1323 status->status =
1324 InvalidArgument("Attribute '", attr_name, "' has no value set");
1325 break;
1326 }
1327 return metadata;
1328 }
1329
TF_OperationGetAttrString(TF_Operation * oper,const char * attr_name,void * value,size_t max_length,TF_Status * status)1330 void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
1331 void* value, size_t max_length,
1332 TF_Status* status) {
1333 const auto* attr = GetAttrValue(oper, attr_name, status);
1334 if (!status->status.ok()) return;
1335 if (attr->value_case() != tensorflow::AttrValue::kS) {
1336 status->status =
1337 InvalidArgument("Attribute '", attr_name, "' is not a string");
1338 return;
1339 }
1340 if (max_length <= 0) {
1341 return;
1342 }
1343 const auto& s = attr->s();
1344 std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length));
1345 }
1346
TF_OperationGetAttrStringList(TF_Operation * oper,const char * attr_name,void ** values,size_t * lengths,int max_values,void * storage,size_t storage_size,TF_Status * status)1347 void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
1348 void** values, size_t* lengths,
1349 int max_values, void* storage,
1350 size_t storage_size, TF_Status* status) {
1351 const auto* attr = GetAttrValue(oper, attr_name, status);
1352 if (!status->status.ok()) return;
1353 if (attr->value_case() != tensorflow::AttrValue::kList) {
1354 status->status =
1355 InvalidArgument("Value for '", attr_name, "' is not a list");
1356 return;
1357 }
1358 const auto len = std::min(max_values, attr->list().s_size());
1359 char* p = static_cast<char*>(storage);
1360 for (int i = 0; i < len; ++i) {
1361 const string& s = attr->list().s(i);
1362 values[i] = p;
1363 lengths[i] = s.size();
1364 if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
1365 status->status = InvalidArgument(
1366 "Not enough storage to hold the requested list of strings");
1367 return;
1368 }
1369 memcpy(values[i], s.data(), s.size());
1370 p += s.size();
1371 }
1372 }
1373
1374 #define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \
1375 void func(TF_Operation* oper, const char* attr_name, c_type* value, \
1376 TF_Status* status) { \
1377 cpp_type v; \
1378 status->status = \
1379 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \
1380 if (!status->status.ok()) return; \
1381 *value = static_cast<c_type>(v); \
1382 } \
1383 void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
1384 int max_values, TF_Status* status) { \
1385 const auto* attr = GetAttrValue(oper, attr_name, status); \
1386 if (!status->status.ok()) return; \
1387 if (attr->value_case() != tensorflow::AttrValue::kList) { \
1388 status->status = \
1389 InvalidArgument("Value for '", attr_name, "' is not a list."); \
1390 return; \
1391 } \
1392 const auto len = std::min(max_values, attr->list().list_field##_size()); \
1393 for (int i = 0; i < len; ++i) { \
1394 values[i] = static_cast<c_type>(attr->list().list_field(i)); \
1395 } \
1396 }
1397 DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, int64_t, i);
1398 DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
1399 DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
1400 DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
1401 #undef DEFINE_GETATTR
1402
TF_OperationGetAttrShape(TF_Operation * oper,const char * attr_name,int64_t * value,int num_dims,TF_Status * status)1403 void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
1404 int64_t* value, int num_dims, TF_Status* status) {
1405 PartialTensorShape shape;
1406 status->status =
1407 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
1408 if (!status->status.ok()) return;
1409 auto len = std::min(shape.dims(), num_dims);
1410 for (int i = 0; i < len; ++i) {
1411 value[i] = shape.dim_size(i);
1412 }
1413 }
1414
TF_OperationGetAttrShapeList(TF_Operation * oper,const char * attr_name,int64_t ** dims,int * num_dims,int num_shapes,int64_t * storage,int storage_size,TF_Status * status)1415 void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
1416 int64_t** dims, int* num_dims, int num_shapes,
1417 int64_t* storage, int storage_size,
1418 TF_Status* status) {
1419 std::vector<PartialTensorShape> shapes;
1420 status->status =
1421 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
1422 if (!status->status.ok()) return;
1423 auto len = std::min(static_cast<int>(shapes.size()), num_shapes);
1424 int64_t* p = storage;
1425 int storage_left = storage_size;
1426 for (int i = 0; i < len; ++i) {
1427 // shapes[i].dims() == -1 for shapes with an unknown rank.
1428 int64_t n = shapes[i].dims();
1429 num_dims[i] = n;
1430 dims[i] = p;
1431 if (n < 0) {
1432 continue;
1433 }
1434 if (storage_left < n) {
1435 status->status = InvalidArgument(
1436 "Not enough storage to hold the requested list of shapes");
1437 return;
1438 }
1439 storage_left -= n;
1440 for (int j = 0; j < n; ++j, ++p) {
1441 *p = shapes[i].dim_size(j);
1442 }
1443 }
1444 }
1445
TF_OperationGetAttrTensorShapeProto(TF_Operation * oper,const char * attr_name,TF_Buffer * value,TF_Status * status)1446 void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
1447 const char* attr_name,
1448 TF_Buffer* value, TF_Status* status) {
1449 const auto* attr = GetAttrValue(oper, attr_name, status);
1450 if (!status->status.ok()) return;
1451 if (attr->value_case() != tensorflow::AttrValue::kShape) {
1452 status->status =
1453 InvalidArgument("Value for '", attr_name, "' is not a shape.");
1454 return;
1455 }
1456 status->status = MessageToBuffer(attr->shape(), value);
1457 }
1458
TF_OperationGetAttrTensorShapeProtoList(TF_Operation * oper,const char * attr_name,TF_Buffer ** values,int max_values,TF_Status * status)1459 void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
1460 const char* attr_name,
1461 TF_Buffer** values, int max_values,
1462 TF_Status* status) {
1463 const auto* attr = GetAttrValue(oper, attr_name, status);
1464 if (!status->status.ok()) return;
1465 if (attr->value_case() != tensorflow::AttrValue::kList) {
1466 status->status =
1467 InvalidArgument("Value for '", attr_name, "' is not a list");
1468 return;
1469 }
1470 const auto len = std::min(max_values, attr->list().shape_size());
1471 for (int i = 0; i < len; ++i) {
1472 values[i] = TF_NewBuffer();
1473 status->status = MessageToBuffer(attr->list().shape(i), values[i]);
1474 if (!status->status.ok()) {
1475 // Delete everything allocated to far, the operation has failed.
1476 for (int j = 0; j <= i; ++j) {
1477 TF_DeleteBuffer(values[j]);
1478 }
1479 return;
1480 }
1481 }
1482 }
1483
TF_OperationGetAttrTensor(TF_Operation * oper,const char * attr_name,TF_Tensor ** value,TF_Status * status)1484 void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
1485 TF_Tensor** value, TF_Status* status) {
1486 *value = nullptr;
1487 Tensor t;
1488 status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
1489 if (!status->status.ok()) return;
1490 *value = TF_TensorFromTensor(t, &status->status);
1491 }
1492
TF_OperationGetAttrTensorList(TF_Operation * oper,const char * attr_name,TF_Tensor ** values,int max_values,TF_Status * status)1493 void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
1494 TF_Tensor** values, int max_values,
1495 TF_Status* status) {
1496 std::vector<Tensor> ts;
1497 status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
1498 if (!status->status.ok()) return;
1499 const auto len = std::min(max_values, static_cast<int>(ts.size()));
1500 for (int i = 0; i < len; ++i) {
1501 values[i] = TF_TensorFromTensor(ts[i], &status->status);
1502 }
1503 }
1504
TF_OperationGetAttrValueProto(TF_Operation * oper,const char * attr_name,TF_Buffer * output_attr_value,TF_Status * status)1505 void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
1506 TF_Buffer* output_attr_value,
1507 TF_Status* status) {
1508 const auto* attr = GetAttrValue(oper, attr_name, status);
1509 if (!status->status.ok()) return;
1510 status->status = MessageToBuffer(*attr, output_attr_value);
1511 }
1512
TF_OperationGetNumAttrs(TF_Operation * oper)1513 int TF_OperationGetNumAttrs(TF_Operation* oper) {
1514 return oper->node.attrs().size();
1515 }
1516
TF_OperationGetAttrNameLength(TF_Operation * oper,int i)1517 int TF_OperationGetAttrNameLength(TF_Operation* oper, int i) {
1518 auto attrs = oper->node.attrs();
1519 int count = 0;
1520 AttrValueMap::const_iterator it;
1521 for (it = attrs.begin(); it != attrs.end(); it++) {
1522 if (count == i) {
1523 return it->first.length();
1524 }
1525 count++;
1526 }
1527 return -1;
1528 }
1529
TF_OperationGetAttrName(TF_Operation * oper,int i,char * output,TF_Status * status)1530 void TF_OperationGetAttrName(TF_Operation* oper, int i, char* output,
1531 TF_Status* status) {
1532 auto attrs = oper->node.attrs();
1533 int count = 0;
1534 AttrValueMap::const_iterator it;
1535 for (it = attrs.begin(); it != attrs.end(); it++) {
1536 if (count == i) {
1537 strncpy(output, it->first.c_str(), it->first.length());
1538 status->status = ::tensorflow::OkStatus();
1539 return;
1540 }
1541 count++;
1542 }
1543 status->status = OutOfRange("Operation only has ", count,
1544 " attributes, can't get the ", i, "th");
1545 }
1546
TF_OperationToNodeDef(TF_Operation * oper,TF_Buffer * output_node_def,TF_Status * status)1547 void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
1548 TF_Status* status) {
1549 status->status = MessageToBuffer(oper->node.def(), output_node_def);
1550 }
1551
// TF_Graph functions ---------------------------------------------------------
1553
TF_Graph()1554 TF_Graph::TF_Graph()
1555 : graph(tensorflow::OpRegistry::Global()),
1556 refiner(graph.versions().producer(), graph.op_registry()),
1557 delete_requested(false),
1558 parent(nullptr),
1559 parent_inputs(nullptr) {
1560 // Tell the shape refiner to also run shape inference on functions.
1561 refiner.set_function_library_for_shape_inference(&graph.flib_def());
1562 }
1563
TF_NewGraph()1564 TF_Graph* TF_NewGraph() { return new TF_Graph; }
1565
TF_DeleteGraph(TF_Graph * g)1566 void TF_DeleteGraph(TF_Graph* g) {
1567 if (g == nullptr) return;
1568 g->mu.lock();
1569 g->delete_requested = true;
1570 const bool del = g->sessions.empty();
1571 g->mu.unlock();
1572 if (del) delete g;
1573 }
1574
TF_GraphOperationByName(TF_Graph * graph,const char * oper_name)1575 TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
1576 mutex_lock l(graph->mu);
1577 auto iter = graph->name_map.find(oper_name);
1578 if (iter == graph->name_map.end()) {
1579 return nullptr;
1580 } else {
1581 return ToOperation(iter->second);
1582 }
1583 }
1584
TF_GraphNextOperation(TF_Graph * graph,size_t * pos)1585 TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
1586 if (*pos == 0) {
1587 // Advance past the first sentinel nodes in every graph (the source & sink).
1588 *pos += 2;
1589 } else {
1590 // Advance to the next node.
1591 *pos += 1;
1592 }
1593
1594 mutex_lock l(graph->mu);
1595 while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) {
1596 Node* node = graph->graph.FindNodeId(*pos);
1597 // FindNodeId() returns nullptr for nodes that have been deleted.
1598 // We aren't currently allowing nodes to be deleted, but it is safer
1599 // to still check.
1600 if (node != nullptr) return ToOperation(node);
1601 *pos += 1;
1602 }
1603
1604 // No more nodes.
1605 return nullptr;
1606 }
1607
TF_GraphToGraphDef(TF_Graph * graph,TF_Buffer * output_graph_def,TF_Status * status)1608 void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
1609 TF_Status* status) {
1610 GraphDef def;
1611 {
1612 mutex_lock l(graph->mu);
1613 graph->graph.ToGraphDef(&def);
1614 }
1615 status->status = MessageToBuffer(def, output_graph_def);
1616 }
1617
TF_GraphGetOpDef(TF_Graph * graph,const char * op_name,TF_Buffer * output_op_def,TF_Status * status)1618 void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
1619 TF_Buffer* output_op_def, TF_Status* status) {
1620 const OpDef* op_def;
1621 {
1622 mutex_lock l(graph->mu);
1623 status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
1624 if (!status->status.ok()) return;
1625 }
1626 status->status = MessageToBuffer(*op_def, output_op_def);
1627 }
1628
TF_GraphVersions(TF_Graph * graph,TF_Buffer * output_version_def,TF_Status * status)1629 void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
1630 TF_Status* status) {
1631 VersionDef versions;
1632 {
1633 mutex_lock l(graph->mu);
1634 versions = graph->graph.versions();
1635 }
1636 status->status = MessageToBuffer(versions, output_version_def);
1637 }
1638
TF_NewImportGraphDefOptions()1639 TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
1640 return new TF_ImportGraphDefOptions;
1641 }
TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions * opts)1642 void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
1643 delete opts;
1644 }
TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions * opts,const char * prefix)1645 void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
1646 const char* prefix) {
1647 opts->opts.prefix = prefix;
1648 }
TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions * opts,const char * device)1649 void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts,
1650 const char* device) {
1651 opts->opts.default_device = device;
1652 }
1653
TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions * opts,unsigned char uniquify_names)1654 void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
1655 unsigned char uniquify_names) {
1656 opts->opts.uniquify_names = uniquify_names;
1657 }
1658
TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions * opts,unsigned char uniquify_prefix)1659 void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
1660 unsigned char uniquify_prefix) {
1661 opts->opts.uniquify_prefix = uniquify_prefix;
1662 }
1663
TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions * opts,const char * src_name,int src_index,TF_Output dst)1664 void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
1665 const char* src_name,
1666 int src_index, TF_Output dst) {
1667 opts->tensor_id_data.push_back(src_name);
1668 const string& src_name_str = opts->tensor_id_data.back();
1669 // We don't need to store dst's name in tensor_id_data, since `dst` must
1670 // outlive the ImportGraphDef call.
1671 opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
1672 }
1673
TF_ImportGraphDefOptionsRemapControlDependency(TF_ImportGraphDefOptions * opts,const char * src_name,TF_Operation * dst)1674 void TF_ImportGraphDefOptionsRemapControlDependency(
1675 TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
1676 opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
1677 TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
1678 }
1679
TF_ImportGraphDefOptionsAddControlDependency(TF_ImportGraphDefOptions * opts,TF_Operation * oper)1680 extern void TF_ImportGraphDefOptionsAddControlDependency(
1681 TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
1682 opts->opts.control_dependencies.push_back(oper->node.name());
1683 }
1684
TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions * opts,const char * oper_name,int index)1685 void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
1686 const char* oper_name, int index) {
1687 opts->tensor_id_data.push_back(oper_name);
1688 const string& oper_name_str = opts->tensor_id_data.back();
1689 opts->opts.return_tensors.emplace_back(oper_name_str, index);
1690 }
1691
TF_ImportGraphDefOptionsNumReturnOutputs(const TF_ImportGraphDefOptions * opts)1692 int TF_ImportGraphDefOptionsNumReturnOutputs(
1693 const TF_ImportGraphDefOptions* opts) {
1694 return opts->opts.return_tensors.size();
1695 }
1696
TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions * opts,const char * oper_name)1697 void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
1698 const char* oper_name) {
1699 opts->opts.return_nodes.push_back(oper_name);
1700 }
1701
TF_ImportGraphDefOptionsNumReturnOperations(const TF_ImportGraphDefOptions * opts)1702 int TF_ImportGraphDefOptionsNumReturnOperations(
1703 const TF_ImportGraphDefOptions* opts) {
1704 return opts->opts.return_nodes.size();
1705 }
1706
TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults * results,int * num_outputs,TF_Output ** outputs)1707 void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
1708 int* num_outputs,
1709 TF_Output** outputs) {
1710 *num_outputs = results->return_tensors.size();
1711 *outputs = results->return_tensors.data();
1712 }
1713
TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults * results,int * num_opers,TF_Operation *** opers)1714 void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
1715 int* num_opers,
1716 TF_Operation*** opers) {
1717 *num_opers = results->return_nodes.size();
1718 *opers = results->return_nodes.data();
1719 }
1720
TF_ImportGraphDefResultsMissingUnusedInputMappings(TF_ImportGraphDefResults * results,int * num_missing_unused_input_mappings,const char *** src_names,int ** src_indexes)1721 void TF_ImportGraphDefResultsMissingUnusedInputMappings(
1722 TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
1723 const char*** src_names, int** src_indexes) {
1724 *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
1725 *src_names = results->missing_unused_key_names.data();
1726 *src_indexes = results->missing_unused_key_indexes.data();
1727 }
1728
TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults * results)1729 void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
1730 delete results;
1731 }
1732
GraphImportGraphDefLocked(TF_Graph * graph,const GraphDef & def,const TF_ImportGraphDefOptions * opts,TF_ImportGraphDefResults * tf_results,TF_Status * status)1733 static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
1734 const TF_ImportGraphDefOptions* opts,
1735 TF_ImportGraphDefResults* tf_results,
1736 TF_Status* status)
1737 TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
1738 const int last_node_id = graph->graph.num_node_ids();
1739 tensorflow::ImportGraphDefResults results;
1740 status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
1741 &graph->refiner, &results);
1742 if (!status->status.ok()) return;
1743
1744 // Add new nodes to name_map
1745 for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
1746 auto* node = graph->graph.FindNodeId(i);
1747 if (node != nullptr) graph->name_map[node->name()] = node;
1748 }
1749
1750 // Populate return_tensors
1751 DCHECK(tf_results->return_tensors.empty());
1752 tf_results->return_tensors.resize(results.return_tensors.size());
1753 for (int i = 0; i < results.return_tensors.size(); ++i) {
1754 tf_results->return_tensors[i].oper =
1755 ToOperation(results.return_tensors[i].first);
1756 tf_results->return_tensors[i].index = results.return_tensors[i].second;
1757 }
1758
1759 // Populate return_nodes
1760 DCHECK(tf_results->return_nodes.empty());
1761 tf_results->return_nodes.resize(results.return_nodes.size());
1762 for (int i = 0; i < results.return_nodes.size(); ++i) {
1763 tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
1764 }
1765
1766 // Populate missing unused map keys
1767 DCHECK(tf_results->missing_unused_key_names.empty());
1768 DCHECK(tf_results->missing_unused_key_indexes.empty());
1769 DCHECK(tf_results->missing_unused_key_names_data.empty());
1770
1771 size_t size = results.missing_unused_input_map_keys.size();
1772 tf_results->missing_unused_key_names.resize(size);
1773 tf_results->missing_unused_key_indexes.resize(size);
1774
1775 for (int i = 0; i < size; ++i) {
1776 TensorId id = results.missing_unused_input_map_keys[i];
1777 tf_results->missing_unused_key_names_data.emplace_back(id.first);
1778 tf_results->missing_unused_key_names[i] =
1779 tf_results->missing_unused_key_names_data.back().c_str();
1780 tf_results->missing_unused_key_indexes[i] = id.second;
1781 }
1782 }
1783
TF_GraphImportGraphDefWithResults(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Status * status)1784 TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
1785 TF_Graph* graph, const TF_Buffer* graph_def,
1786 const TF_ImportGraphDefOptions* options, TF_Status* status) {
1787 GraphDef def;
1788 if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
1789 graph_def->length)) {
1790 status->status = InvalidArgument("Invalid GraphDef");
1791 return nullptr;
1792 }
1793 auto results = new TF_ImportGraphDefResults();
1794 mutex_lock l(graph->mu);
1795 GraphImportGraphDefLocked(graph, def, options, results, status);
1796 if (!status->status.ok()) {
1797 delete results;
1798 return nullptr;
1799 }
1800 return results;
1801 }
1802
TF_GraphImportGraphDefWithReturnOutputs(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Output * return_outputs,int num_return_outputs,TF_Status * status)1803 void TF_GraphImportGraphDefWithReturnOutputs(
1804 TF_Graph* graph, const TF_Buffer* graph_def,
1805 const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
1806 int num_return_outputs, TF_Status* status) {
1807 if (num_return_outputs != options->opts.return_tensors.size()) {
1808 status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
1809 options->opts.return_tensors.size(),
1810 ", got ", num_return_outputs);
1811 return;
1812 }
1813 if (num_return_outputs > 0 && return_outputs == nullptr) {
1814 status->status = InvalidArgument(
1815 "'return_outputs' must be preallocated to length ", num_return_outputs);
1816 return;
1817 }
1818 GraphDef def;
1819 if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
1820 graph_def->length)) {
1821 status->status = InvalidArgument("Invalid GraphDef");
1822 return;
1823 }
1824 TF_ImportGraphDefResults results;
1825 mutex_lock l(graph->mu);
1826 GraphImportGraphDefLocked(graph, def, options, &results, status);
1827 DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
1828 memcpy(return_outputs, results.return_tensors.data(),
1829 num_return_outputs * sizeof(TF_Output));
1830 }
1831
TF_GraphImportGraphDef(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Status * status)1832 void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
1833 const TF_ImportGraphDefOptions* options,
1834 TF_Status* status) {
1835 TF_ImportGraphDefResults* results =
1836 TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
1837 TF_DeleteImportGraphDefResults(results);
1838 }
1839
1840 // While loop functions -------------------------------------------------------
1841
1842 namespace {
1843
1844 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1845
1846 // Creates a placeholder representing an input to the cond or body graph.
1847 // TODO(skyewm): remove these from final graph
CreateInput(const TF_Output & parent_input,TF_Graph * g,const char * name,TF_Output * input,TF_Status * status)1848 bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
1849 TF_Output* input, TF_Status* status) {
1850 TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
1851 TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
1852 // TODO(skyewm): set placeholder shape
1853 TF_Operation* oper = TF_FinishOperation(desc, status);
1854 if (!status->status.ok()) return false;
1855 *input = {oper, 0};
1856 return true;
1857 }
1858
1859 // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
1860 // `src_inputs[i]` will have that input replaced with `dst_inputs[i]`. `prefix`
1861 // will be prepended to copied node names. `control_deps` are nodes in
1862 // `dst_graph` that the copied `src_graph` nodes will have control dependencies
1863 // on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
1864 // in `dst_graph` will be returned. `return_nodes` must be non-null.
CopyGraph(Graph * src_graph,Graph * dst_graph,tensorflow::ShapeRefiner * dst_refiner,const TF_Output * src_inputs,const std::vector<tensorflow::Output> & dst_inputs,const string & prefix,const std::vector<tensorflow::Operation> & control_deps,const TF_Output * nodes_to_return,int nreturn_nodes,std::vector<tensorflow::Output> * return_nodes)1865 tensorflow::Status CopyGraph(Graph* src_graph, Graph* dst_graph,
1866 tensorflow::ShapeRefiner* dst_refiner,
1867 const TF_Output* src_inputs,
1868 const std::vector<tensorflow::Output>& dst_inputs,
1869 const string& prefix,
1870 const std::vector<tensorflow::Operation>& control_deps,
1871 const TF_Output* nodes_to_return, int nreturn_nodes,
1872 std::vector<tensorflow::Output>* return_nodes) {
1873 DCHECK(return_nodes != nullptr);
1874 GraphDef gdef;
1875 src_graph->ToGraphDef(&gdef);
1876
1877 tensorflow::ImportGraphDefOptions opts;
1878 opts.prefix = prefix;
1879
1880 for (int i = 0; i < dst_inputs.size(); ++i) {
1881 opts.input_map[ToTensorId(src_inputs[i])] =
1882 TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
1883 }
1884 opts.skip_mapped_nodes = true;
1885
1886 for (const tensorflow::Operation& op : control_deps) {
1887 opts.control_dependencies.push_back(op.node()->name());
1888 }
1889
1890 for (int i = 0; i < nreturn_nodes; ++i) {
1891 opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
1892 }
1893
1894 // TODO(skyewm): change to OutputTensor
1895 tensorflow::ImportGraphDefResults results;
1896 TF_RETURN_IF_ERROR(
1897 ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
1898
1899 for (const auto& pair : results.return_tensors) {
1900 return_nodes->emplace_back(pair.first, pair.second);
1901 }
1902 return ::tensorflow::OkStatus();
1903 }
1904
ValidateConstWhileParams(const TF_WhileParams & params,TF_Status * s)1905 bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
1906 if (params.cond_graph == nullptr || params.body_graph == nullptr ||
1907 params.cond_graph->parent == nullptr ||
1908 params.cond_graph->parent != params.body_graph->parent ||
1909 params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
1910 params.ninputs <= 0 || params.cond_inputs == nullptr ||
1911 params.body_inputs == nullptr || params.body_outputs == nullptr) {
1912 s->status = InvalidArgument(
1913 "TF_WhileParams must be created by successful TF_NewWhile() call");
1914 return false;
1915 }
1916 return true;
1917 }
1918
ValidateInputWhileParams(const TF_WhileParams & params,TF_Status * s)1919 bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
1920 if (params.cond_output.oper == nullptr) {
1921 s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
1922 return false;
1923 }
1924 for (int i = 0; i < params.ninputs; ++i) {
1925 if (params.body_outputs[i].oper == nullptr) {
1926 s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
1927 "field isn't set");
1928 return false;
1929 }
1930 }
1931 if (params.name == nullptr) {
1932 s->status = InvalidArgument("TF_WhileParams `name` field is null");
1933 return false;
1934 }
1935 return true;
1936 }
1937
1938 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1939
FreeWhileResources(const TF_WhileParams * params)1940 void FreeWhileResources(const TF_WhileParams* params) {
1941 TF_DeleteGraph(params->cond_graph);
1942 TF_DeleteGraph(params->body_graph);
1943 delete[] params->cond_inputs;
1944 delete[] params->body_inputs;
1945 delete[] params->body_outputs;
1946 }
1947
EmptyWhileParams()1948 TF_WhileParams EmptyWhileParams() {
1949 return {0, nullptr, nullptr, {nullptr, 0},
1950 nullptr, nullptr, nullptr, nullptr};
1951 }
1952
1953 } // namespace
1954
TF_NewWhile(TF_Graph * g,TF_Output * inputs,int ninputs,TF_Status * status)1955 TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
1956 TF_Status* status) {
1957 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
1958 status->status = tensorflow::errors::Unimplemented(
1959 "Creating while loops is not supported on mobile. File a bug at "
1960 "https://github.com/tensorflow/tensorflow/issues if this feature is "
1961 "important to you");
1962 return EmptyWhileParams();
1963 #else
1964 if (ninputs == 0) {
1965 status->status =
1966 InvalidArgument("TF_NewWhile() must be passed at least one input");
1967 return EmptyWhileParams();
1968 }
1969
1970 TF_Graph* cond_graph = TF_NewGraph();
1971 TF_Graph* body_graph = TF_NewGraph();
1972 cond_graph->parent = g;
1973 cond_graph->parent_inputs = inputs;
1974 body_graph->parent = g;
1975 body_graph->parent_inputs = inputs;
1976
1977 TF_Output* cond_inputs = new TF_Output[ninputs];
1978 TF_Output cond_output = {nullptr, -1};
1979 TF_Output* body_inputs = new TF_Output[ninputs];
1980 TF_Output* body_outputs = new TF_Output[ninputs];
1981 for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
1982 const char* name = nullptr;
1983
1984 for (int i = 0; i < ninputs; ++i) {
1985 // TODO(skyewm): prefix names with underscore (requires some plumbing)
1986 if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
1987 &cond_inputs[i], status)) {
1988 break;
1989 }
1990 if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
1991 &body_inputs[i], status)) {
1992 break;
1993 }
1994 }
1995
1996 TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output,
1997 body_graph, body_inputs, body_outputs, name};
1998
1999 if (!status->status.ok()) {
2000 FreeWhileResources(¶ms);
2001 return EmptyWhileParams();
2002 }
2003 return params;
2004 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2005 }
2006
2007 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2008 namespace {
2009
2010 // TODO(skyewm): make nodes in while loop unfetchable like in Python version
TF_FinishWhileHelper(const TF_WhileParams * params,TF_Status * status,TF_Output * outputs)2011 void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
2012 TF_Output* outputs) {
2013 if (!ValidateInputWhileParams(*params, status)) return;
2014
2015 TF_Graph* parent = params->cond_graph->parent;
2016 TF_Output* parent_inputs = params->cond_graph->parent_inputs;
2017 int num_loop_vars = params->ninputs;
2018
2019 mutex_lock l(parent->mu);
2020
2021 // 'cond_fn' copies the cond graph into the parent graph.
2022 tensorflow::ops::CondGraphBuilderFn cond_fn =
2023 [params, parent](const tensorflow::Scope& scope,
2024 const std::vector<tensorflow::Output>& inputs,
2025 tensorflow::Output* output) {
2026 DCHECK_EQ(scope.graph(), &parent->graph);
2027 std::vector<tensorflow::Output> cond_output;
2028 TF_RETURN_IF_ERROR(CopyGraph(
2029 ¶ms->cond_graph->graph, &parent->graph, &parent->refiner,
2030 params->cond_inputs, inputs, scope.impl()->name(),
2031 scope.impl()->control_deps(), ¶ms->cond_output,
2032 /* nreturn_nodes */ 1, &cond_output));
2033 *output = cond_output[0];
2034 return ::tensorflow::OkStatus();
2035 };
2036
2037 // 'body_fn' copies the body graph into the parent graph.
2038 tensorflow::ops::BodyGraphBuilderFn body_fn =
2039 [params, parent, num_loop_vars](
2040 const tensorflow::Scope& scope,
2041 const std::vector<tensorflow::Output>& inputs,
2042 std::vector<tensorflow::Output>* outputs) {
2043 DCHECK_EQ(scope.graph(), &parent->graph);
2044 TF_RETURN_IF_ERROR(
2045 CopyGraph(¶ms->body_graph->graph, &parent->graph,
2046 &parent->refiner, params->body_inputs, inputs,
2047 scope.impl()->name(), scope.impl()->control_deps(),
2048 params->body_outputs, num_loop_vars, outputs));
2049 return ::tensorflow::OkStatus();
2050 };
2051
2052 // Create the while loop using an internal scope.
2053 tensorflow::Scope scope =
2054 NewInternalScope(&parent->graph, &status->status, &parent->refiner)
2055 .NewSubScope(params->name);
2056
2057 const int first_new_node_id = parent->graph.num_node_ids();
2058
2059 tensorflow::OutputList loop_outputs;
2060 status->status = tensorflow::ops::BuildWhileLoop(
2061 scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
2062 body_fn, params->name, &loop_outputs);
2063
2064 // Update name_map with newly-created ops.
2065 // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
2066 // a bad status. Once we fix this, we may want to return early instead of
2067 // executing the following code.
2068 for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
2069 Node* new_node = parent->graph.FindNodeId(i);
2070 if (new_node == nullptr) continue;
2071 parent->name_map[new_node->name()] = new_node;
2072 }
2073
2074 // Populate 'outputs'.
2075 DCHECK_LE(loop_outputs.size(), num_loop_vars);
2076 for (int i = 0; i < loop_outputs.size(); ++i) {
2077 outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
2078 }
2079 }
2080
2081 } // namespace
2082 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2083
TF_FinishWhile(const TF_WhileParams * params,TF_Status * status,TF_Output * outputs)2084 void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
2085 TF_Output* outputs) {
2086 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2087 status->status = tensorflow::errors::Unimplemented(
2088 "Creating while loops is not supported on mobile. File a bug at "
2089 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2090 "important to you");
2091 #else
2092 // If it appears the caller created or modified `params`, don't free resources
2093 if (!ValidateConstWhileParams(*params, status)) return;
2094 TF_FinishWhileHelper(params, status, outputs);
2095 FreeWhileResources(params);
2096 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2097 }
2098
TF_AbortWhile(const TF_WhileParams * params)2099 void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
2100
TF_AddGradients(TF_Graph * g,TF_Output * y,int ny,TF_Output * x,int nx,TF_Output * dx,TF_Status * status,TF_Output * dy)2101 void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
2102 TF_Output* dx, TF_Status* status, TF_Output* dy) {
2103 TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy);
2104 }
2105
TF_AddGradientsWithPrefix(TF_Graph * g,const char * prefix,TF_Output * y,int ny,TF_Output * x,int nx,TF_Output * dx,TF_Status * status,TF_Output * dy)2106 void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
2107 int ny, TF_Output* x, int nx, TF_Output* dx,
2108 TF_Status* status, TF_Output* dy) {
2109 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2110 status->status = tensorflow::errors::Unimplemented(
2111 "Adding gradients is not supported on mobile. File a bug at "
2112 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2113 "important to you");
2114 #else
2115 std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
2116 std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
2117 std::vector<tensorflow::Output> dy_arg;
2118
2119 {
2120 // We need to hold on to the lock while we have a scope that uses TF_Graph.
2121 mutex_lock graph_lock(g->mu);
2122
2123 const int first_new_node_id = g->graph.num_node_ids();
2124
2125 string prefix_cmp;
2126 const char* child_scope_name;
2127 if (prefix == nullptr) {
2128 child_scope_name = "gradients";
2129 } else {
2130 prefix_cmp = string(prefix) + "/";
2131 // The operation should fail if the provided name prefix has already been
2132 // used in this graph
2133 for (const auto& pair : g->name_map) {
2134 const string& name = pair.first;
2135 if ((name == prefix) || absl::StartsWith(name, prefix_cmp)) {
2136 status->status = InvalidArgument(
2137 "prefix [", prefix,
2138 "] conflicts with existing node in the graph named [", name, "]");
2139 return;
2140 }
2141 }
2142 child_scope_name = prefix;
2143 }
2144 tensorflow::Scope scope =
2145 NewInternalScope(&g->graph, &status->status, &g->refiner)
2146 .NewSubScope(child_scope_name);
2147
2148 if (dx != nullptr) {
2149 std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
2150 status->status =
2151 AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
2152 } else {
2153 status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
2154 }
2155
2156 // Update g->name_map with the name_map from the scope, which will contain
2157 // the new gradient ops.
2158 for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
2159 Node* n = g->graph.FindNodeId(i);
2160 if (n == nullptr) continue;
2161
2162 // Adding the gradients to the graph can alter the prefix to prevent
2163 // name collisions only if this prefix has not been provided explicitly
2164 // by the user. If it was provided, assert that it remained intact.
2165 if (prefix != nullptr && !absl::StartsWith(n->name(), prefix_cmp)) {
2166 status->status = tensorflow::errors::Internal(
2167 "BUG: The gradients prefix have been unexpectedly altered when "
2168 "adding the nodes to the graph. This is a bug. Please file an "
2169 "issue at https://github.com/tensorflow/tensorflow/issues.");
2170 return;
2171 }
2172 // We have a convoluted scheme here: Using the C++ graph construction API
2173 // to add potentially many nodes to the graph without running the checks
2174 // (such as uniqueness of the names of nodes) we run with other functions
2175 // that add a node to the graph (like TF_FinishOperation).
2176 if (!g->name_map.insert(std::make_pair(n->name(), n)).second) {
2177 status->status = tensorflow::errors::Internal(
2178 "BUG: The API allowed construction of a graph with duplicate node "
2179 "names (",
2180 n->name(),
2181 "). This is a bug. Please file an issue at "
2182 "https://github.com/tensorflow/tensorflow/issues.");
2183 }
2184 }
2185 }
2186
2187 // Unpack the results from grad_outputs_arg.
2188 TFOutputsFromOutputs(dy_arg, dy);
2189 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2190 }
2191
2192 // TF_Session functions ----------------------------------------------
2193
TF_Session(tensorflow::Session * s,TF_Graph * g)2194 TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
2195 : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {}
2196
TF_NewSession(TF_Graph * graph,const TF_SessionOptions * opt,TF_Status * status)2197 TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
2198 TF_Status* status) {
2199 Session* session;
2200 status->status = NewSession(opt->options, &session);
2201 if (status->status.ok()) {
2202 TF_Session* new_session = new TF_Session(session, graph);
2203 if (graph != nullptr) {
2204 mutex_lock l(graph->mu);
2205 graph->sessions[new_session] = "";
2206 }
2207 return new_session;
2208 } else {
2209 LOG(ERROR) << status->status;
2210 DCHECK_EQ(nullptr, session);
2211 return nullptr;
2212 }
2213 }
2214
TF_LoadSessionFromSavedModel(const TF_SessionOptions * session_options,const TF_Buffer * run_options,const char * export_dir,const char * const * tags,int tags_len,TF_Graph * graph,TF_Buffer * meta_graph_def,TF_Status * status)2215 TF_Session* TF_LoadSessionFromSavedModel(
2216 const TF_SessionOptions* session_options, const TF_Buffer* run_options,
2217 const char* export_dir, const char* const* tags, int tags_len,
2218 TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
2219 // TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring
2220 // that the tensorflow/cc/saved_model:loader build target is mobile friendly.
2221 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2222 status->status = tensorflow::errors::Unimplemented(
2223 "Loading a SavedModel is not supported on mobile. File a bug at "
2224 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2225 "important to you");
2226 return nullptr;
2227 #else
2228 mutex_lock l(graph->mu);
2229 if (!graph->name_map.empty()) {
2230 status->status = InvalidArgument("Graph is non-empty.");
2231 return nullptr;
2232 }
2233
2234 RunOptions run_options_proto;
2235 if (run_options != nullptr && !run_options_proto.ParseFromArray(
2236 run_options->data, run_options->length)) {
2237 status->status = InvalidArgument("Unparseable RunOptions proto");
2238 return nullptr;
2239 }
2240
2241 std::unordered_set<string> tag_set;
2242 for (int i = 0; i < tags_len; i++) {
2243 tag_set.insert(string(tags[i]));
2244 }
2245
2246 tensorflow::SavedModelBundle bundle;
2247 status->status =
2248 tensorflow::LoadSavedModel(session_options->options, run_options_proto,
2249 export_dir, tag_set, &bundle);
2250 if (!status->status.ok()) return nullptr;
2251
2252 // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
2253 // extends using GraphDefs. The Graph instance is different, but equivalent
2254 // to the one used to create the session.
2255 //
2256 // TODO(jhseu): When Session is modified to take Graphs instead of
2257 // GraphDefs, return the Graph generated in LoadSavedModel().
2258 TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
2259 TF_ImportGraphDefResults results;
2260 GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
2261 import_opts, &results, status);
2262 TF_DeleteImportGraphDefOptions(import_opts);
2263 if (!status->status.ok()) return nullptr;
2264
2265 if (meta_graph_def != nullptr) {
2266 status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
2267 if (!status->status.ok()) return nullptr;
2268 }
2269
2270 TF_Session* session = new TF_Session(bundle.session.release(), graph);
2271
2272 graph->sessions[session] = "";
2273 session->last_num_graph_nodes = graph->graph.num_node_ids();
2274 return session;
2275 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2276 }
2277
TF_CloseSession(TF_Session * s,TF_Status * status)2278 void TF_CloseSession(TF_Session* s, TF_Status* status) {
2279 status->status = s->session->Close();
2280 }
2281
TF_DeleteSession(TF_Session * s,TF_Status * status)2282 void TF_DeleteSession(TF_Session* s, TF_Status* status) {
2283 status->status = ::tensorflow::OkStatus();
2284 if (s == nullptr) return;
2285 TF_Graph* const graph = s->graph;
2286 if (graph != nullptr) {
2287 graph->mu.lock();
2288 graph->sessions.erase(s);
2289 const bool del = graph->delete_requested && graph->sessions.empty();
2290 graph->mu.unlock();
2291 if (del) delete graph;
2292 }
2293 delete s->session;
2294 delete s;
2295 }
2296
TF_SessionRun(TF_Session * session,const TF_Buffer * run_options,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Buffer * run_metadata,TF_Status * status)2297 void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
2298 const TF_Output* inputs, TF_Tensor* const* input_values,
2299 int ninputs, const TF_Output* outputs,
2300 TF_Tensor** output_values, int noutputs,
2301 const TF_Operation* const* target_opers, int ntargets,
2302 TF_Buffer* run_metadata, TF_Status* status) {
2303 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2304 // directly, instead of requiring us to serialize to a GraphDef and
2305 // call Session::Extend().
2306 if (session->extend_before_run &&
2307 !ExtendSessionGraphHelper(session, status)) {
2308 return;
2309 }
2310
2311 TF_Run_Setup(noutputs, output_values, status);
2312
2313 // Convert from TF_Output and TF_Tensor to a string and Tensor.
2314 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2315 if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2316 for (int i = 0; i < ninputs; ++i) {
2317 input_pairs[i].first = OutputName(inputs[i]);
2318 }
2319
2320 // Convert from TF_Output to string names.
2321 std::vector<string> output_names(noutputs);
2322 for (int i = 0; i < noutputs; ++i) {
2323 output_names[i] = OutputName(outputs[i]);
2324 }
2325
2326 // Convert from TF_Operation* to string names.
2327 std::vector<string> target_names(ntargets);
2328 for (int i = 0; i < ntargets; ++i) {
2329 target_names[i] = target_opers[i]->node.name();
2330 }
2331
2332 // Actually run.
2333 TF_Run_Helper(session->session, nullptr, run_options, input_pairs,
2334 output_names, output_values, target_names, run_metadata,
2335 status);
2336 }
2337
TF_SessionPRunSetup(TF_Session * session,const TF_Output * inputs,int ninputs,const TF_Output * outputs,int noutputs,const TF_Operation * const * target_opers,int ntargets,const char ** handle,TF_Status * status)2338 void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
2339 int ninputs, const TF_Output* outputs, int noutputs,
2340 const TF_Operation* const* target_opers, int ntargets,
2341 const char** handle, TF_Status* status) {
2342 *handle = nullptr;
2343
2344 if (session->extend_before_run &&
2345 !ExtendSessionGraphHelper(session, status)) {
2346 return;
2347 }
2348
2349 std::vector<string> input_names(ninputs);
2350 for (int i = 0; i < ninputs; ++i) {
2351 input_names[i] = OutputName(inputs[i]);
2352 }
2353
2354 std::vector<string> output_names(noutputs);
2355 for (int i = 0; i < noutputs; ++i) {
2356 output_names[i] = OutputName(outputs[i]);
2357 }
2358
2359 std::vector<string> target_names(ntargets);
2360 for (int i = 0; i < ntargets; ++i) {
2361 target_names[i] = target_opers[i]->node.name();
2362 }
2363
2364 string new_handle;
2365 status->status = session->session->PRunSetup(input_names, output_names,
2366 target_names, &new_handle);
2367 if (status->status.ok()) {
2368 char* buf = new char[new_handle.size() + 1];
2369 memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
2370 *handle = buf;
2371 }
2372 }
2373
// Frees a partial-run handle string. `handle` must have been allocated with
// new[] (as TF_SessionPRunSetup does); passing nullptr is harmless.
void TF_DeletePRunHandle(const char* handle) {
  // TODO(suharshs): Free up any resources held by the partial run state.
  delete[] handle;
}
2378
TF_SessionPRun(TF_Session * session,const char * handle,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Status * status)2379 void TF_SessionPRun(TF_Session* session, const char* handle,
2380 const TF_Output* inputs, TF_Tensor* const* input_values,
2381 int ninputs, const TF_Output* outputs,
2382 TF_Tensor** output_values, int noutputs,
2383 const TF_Operation* const* target_opers, int ntargets,
2384 TF_Status* status) {
2385 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2386 // directly, instead of requiring us to serialize to a GraphDef and
2387 // call Session::Extend().
2388 if (session->extend_before_run &&
2389 !ExtendSessionGraphHelper(session, status)) {
2390 return;
2391 }
2392
2393 TF_Run_Setup(noutputs, output_values, status);
2394
2395 // Convert from TF_Output and TF_Tensor to a string and Tensor.
2396 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2397 if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2398 for (int i = 0; i < ninputs; ++i) {
2399 input_pairs[i].first = OutputName(inputs[i]);
2400 }
2401
2402 // Convert from TF_Output to string names.
2403 std::vector<string> output_names(noutputs);
2404 for (int i = 0; i < noutputs; ++i) {
2405 output_names[i] = OutputName(outputs[i]);
2406 }
2407
2408 // Convert from TF_Operation* to string names.
2409 std::vector<string> target_names(ntargets);
2410 for (int i = 0; i < ntargets; ++i) {
2411 target_names[i] = target_opers[i]->node.name();
2412 }
2413
2414 TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
2415 output_values, target_names, nullptr, status);
2416 }
2417
TF_TryEvaluateConstant(TF_Graph * graph,TF_Output output,TF_Tensor ** result,TF_Status * status)2418 unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
2419 TF_Tensor** result, TF_Status* status) {
2420 *result = nullptr;
2421 mutex_lock l(graph->mu);
2422 OutputTensor tensor(&output.oper->node, output.index);
2423 bool evaluated;
2424 Tensor result_tensor;
2425 status->status = EvaluateConstantTensor(
2426 tensor, graph->refiner, *graph->graph.op_registry(),
2427 graph->graph.versions().producer(), &evaluated, &result_tensor);
2428 if (evaluated) {
2429 DCHECK(status->status.ok());
2430 *result = TF_TensorFromTensor(result_tensor, &status->status);
2431 if (!status->status.ok()) evaluated = false;
2432 }
2433 return evaluated;
2434 }
2435
TF_NewApiDefMap(TF_Buffer * op_list_buffer,TF_Status * status)2436 TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
2437 tensorflow::OpList op_list;
2438 if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
2439 status->status = InvalidArgument("Unparseable OpList");
2440 return nullptr;
2441 }
2442 status->status = ::tensorflow::OkStatus();
2443 return new TF_ApiDefMap(op_list);
2444 }
2445
TF_DeleteApiDefMap(TF_ApiDefMap * apimap)2446 void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }
2447
TF_ApiDefMapPut(TF_ApiDefMap * api_def_map,const char * text,size_t text_len,TF_Status * status)2448 void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
2449 size_t text_len, TF_Status* status) {
2450 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2451 status->status = tensorflow::errors::Unimplemented(
2452 "ApiDefMap is not supported on mobile.");
2453 #else
2454 mutex_lock l(api_def_map->lock);
2455 if (api_def_map->update_docs_called) {
2456 status->status = FailedPrecondition(
2457 "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
2458 "called.");
2459 return;
2460 }
2461 string api_def_text(text, text_len);
2462 status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
2463 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2464 }
2465
TF_ApiDefMapGet(TF_ApiDefMap * api_def_map,const char * name,size_t name_len,TF_Status * status)2466 TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
2467 size_t name_len, TF_Status* status) {
2468 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2469 status->status = tensorflow::errors::Unimplemented(
2470 "ApiDefMap is not supported on mobile.");
2471 return nullptr;
2472 #else
2473 mutex_lock l(api_def_map->lock);
2474 if (!api_def_map->update_docs_called) {
2475 api_def_map->api_def_map.UpdateDocs();
2476 api_def_map->update_docs_called = true;
2477 }
2478 string name_str(name, name_len);
2479 const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
2480 if (api_def == nullptr) {
2481 return nullptr;
2482 }
2483
2484 TF_Buffer* ret = TF_NewBuffer();
2485 status->status = MessageToBuffer(*api_def, ret);
2486 if (!status->status.ok()) {
2487 TF_DeleteBuffer(ret);
2488 return nullptr;
2489 }
2490 return ret;
2491 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2492 }
2493
TF_GetAllRegisteredKernels(TF_Status * status)2494 TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
2495 tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
2496 TF_Buffer* ret = TF_NewBuffer();
2497 status->status = MessageToBuffer(kernel_list, ret);
2498 if (!status->status.ok()) {
2499 TF_DeleteBuffer(ret);
2500 return nullptr;
2501 }
2502 return ret;
2503 }
2504
TF_GetRegisteredKernelsForOp(const char * name,TF_Status * status)2505 TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
2506 tensorflow::KernelList kernel_list =
2507 tensorflow::GetRegisteredKernelsForOp(name);
2508 TF_Buffer* ret = TF_NewBuffer();
2509 status->status = MessageToBuffer(kernel_list, ret);
2510 if (!status->status.ok()) {
2511 TF_DeleteBuffer(ret);
2512 return nullptr;
2513 }
2514 return ret;
2515 }
2516
TF_UpdateEdge(TF_Graph * graph,TF_Output new_src,TF_Input dst,TF_Status * status)2517 void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
2518 TF_Status* status) {
2519 using tensorflow::RecordMutation;
2520 mutex_lock l(graph->mu);
2521 tensorflow::shape_inference::InferenceContext* ic =
2522 graph->refiner.GetContext(&new_src.oper->node);
2523
2524 if (ic->num_outputs() <= new_src.index) {
2525 status->status = tensorflow::errors::OutOfRange(
2526 "Cannot update edge. Output index [", new_src.index,
2527 "] is greater than the number of total outputs [", ic->num_outputs(),
2528 "].");
2529 return;
2530 }
2531 tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
2532
2533 tensorflow::shape_inference::InferenceContext* ic_dst =
2534 graph->refiner.GetContext(&dst.oper->node);
2535 if (ic_dst->num_inputs() <= dst.index) {
2536 status->status = tensorflow::errors::OutOfRange(
2537 "Cannot update edge. Input index [", dst.index,
2538 "] is greater than the number of total inputs [", ic_dst->num_inputs(),
2539 "].");
2540 return;
2541 }
2542 if (!ic_dst->MergeInput(dst.index, shape)) {
2543 status->status = tensorflow::errors::InvalidArgument(
2544 "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
2545 " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
2546 return;
2547 }
2548 status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
2549 &dst.oper->node, dst.index);
2550
2551 if (TF_GetCode(status) == TF_OK) {
2552 // This modification only updates the destination node for
2553 // the purposes of running this graph in a session. Thus, we don't
2554 // record the source node as being modified.
2555 RecordMutation(graph, *dst.oper, "updating input tensor");
2556 }
2557 }
2558
2559 // TF_Server functions ----------------------------------------------
2560
2561 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)2562 TF_Server::TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)
2563 : target(server->target()), server(std::move(server)) {}
2564 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2565
TF_NewServer(const void * proto,size_t proto_len,TF_Status * status)2566 TF_Server* TF_NewServer(const void* proto, size_t proto_len,
2567 TF_Status* status) {
2568 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2569 status->status = tensorflow::errors::Unimplemented(
2570 "Server functionality is not supported on mobile");
2571 return nullptr;
2572 #else
2573 tensorflow::ServerDef server_def;
2574 if (!server_def.ParseFromArray(proto, static_cast<int>(proto_len))) {
2575 status->status = InvalidArgument(
2576 "Could not parse provided bytes into a ServerDef protocol buffer");
2577 return nullptr;
2578 }
2579
2580 std::unique_ptr<tensorflow::ServerInterface> out_server;
2581 status->status = tensorflow::NewServer(server_def, &out_server);
2582 if (!status->status.ok()) return nullptr;
2583
2584 return new TF_Server(std::move(out_server));
2585 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2586 }
2587
TF_ServerStart(TF_Server * server,TF_Status * status)2588 void TF_ServerStart(TF_Server* server, TF_Status* status) {
2589 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2590 status->status = tensorflow::errors::Unimplemented(
2591 "Server functionality is not supported on mobile");
2592 #else
2593 status->status = server->server->Start();
2594 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2595 }
2596
TF_ServerStop(TF_Server * server,TF_Status * status)2597 void TF_ServerStop(TF_Server* server, TF_Status* status) {
2598 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2599 status->status = tensorflow::errors::Unimplemented(
2600 "Server functionality is not supported on mobile");
2601 #else
2602 status->status = server->server->Stop();
2603 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2604 }
2605
TF_ServerJoin(TF_Server * server,TF_Status * status)2606 void TF_ServerJoin(TF_Server* server, TF_Status* status) {
2607 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2608 status->status = tensorflow::errors::Unimplemented(
2609 "Server functionality is not supported on mobile");
2610 #else
2611 status->status = server->server->Join();
2612 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2613 }
2614
TF_ServerTarget(TF_Server * server)2615 const char* TF_ServerTarget(TF_Server* server) {
2616 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2617 return nullptr;
2618 #else
2619 return server->target.c_str();
2620 #endif
2621 }
2622
TF_DeleteServer(TF_Server * server)2623 void TF_DeleteServer(TF_Server* server) {
2624 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2625 delete server;
2626 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2627 }
2628
TF_RegisterLogListener(void (* listener)(const char *))2629 void TF_RegisterLogListener(void (*listener)(const char*)) {
2630 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2631 tensorflow::logging::RegisterListener(listener);
2632 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2633 }
2634
TF_RegisterFilesystemPlugin(const char * plugin_filename,TF_Status * status)2635 void TF_RegisterFilesystemPlugin(const char* plugin_filename,
2636 TF_Status* status) {
2637 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2638 status->status = tensorflow::errors::Unimplemented(
2639 "FileSystem plugin functionality is not supported on mobile");
2640 #else
2641 status->status = tensorflow::RegisterFilesystemPlugin(plugin_filename);
2642 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2643 }
2644
2645 } // end extern "C"
2646