1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/c_api.h"
17
18 #include <algorithm>
19 #include <limits>
20 #include <memory>
21 #include <vector>
22
23 #include "absl/strings/match.h"
24 // Required for IS_MOBILE_PLATFORM
25 #include "tensorflow/core/platform/platform.h" // NOLINT
26
27 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
28 #include "tensorflow/c/experimental/filesystem/modular_filesystem.h"
29 #include "tensorflow/cc/framework/gradients.h"
30 #include "tensorflow/cc/framework/ops.h"
31 #include "tensorflow/cc/framework/scope_internal.h"
32 #include "tensorflow/cc/ops/while_loop.h"
33 #include "tensorflow/cc/saved_model/loader.h"
34 #include "tensorflow/core/distributed_runtime/server_lib.h"
35 #include "tensorflow/core/framework/logging.h"
36 #include "tensorflow/core/framework/op_gen_lib.h"
37 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
38 #include "tensorflow/c/c_api_internal.h"
39 #include "tensorflow/c/tf_status_internal.h"
40 #include "tensorflow/c/tf_tensor.h"
41 #include "tensorflow/core/common_runtime/device_mgr.h"
42 #include "tensorflow/core/common_runtime/eval_const_tensor.h"
43 #include "tensorflow/core/common_runtime/graph_constructor.h"
44 #include "tensorflow/core/common_runtime/shape_refiner.h"
45 #include "tensorflow/core/framework/allocation_description.pb.h"
46 #include "tensorflow/core/framework/kernel_def.pb.h"
47 #include "tensorflow/core/framework/log_memory.h"
48 #include "tensorflow/core/framework/node_def_util.h"
49 #include "tensorflow/core/framework/op_kernel.h"
50 #include "tensorflow/core/framework/partial_tensor_shape.h"
51 #include "tensorflow/core/framework/tensor.h"
52 #include "tensorflow/core/framework/tensor.pb.h" // NOLINT
53 #include "tensorflow/core/framework/tensor_shape.h"
54 #include "tensorflow/core/framework/tensor_shape.pb.h"
55 #include "tensorflow/core/framework/types.h"
56 #include "tensorflow/core/framework/versions.pb.h"
57 #include "tensorflow/core/graph/graph.h"
58 #include "tensorflow/core/graph/node_builder.h"
59 #include "tensorflow/core/graph/validate.h"
60 #include "tensorflow/core/lib/gtl/array_slice.h"
61 #include "tensorflow/core/platform/coding.h"
62 #include "tensorflow/core/platform/errors.h"
63 #include "tensorflow/core/platform/mem.h"
64 #include "tensorflow/core/platform/mutex.h"
65 #include "tensorflow/core/platform/protobuf.h"
66 #include "tensorflow/core/platform/status.h"
67 #include "tensorflow/core/platform/str_util.h"
68 #include "tensorflow/core/platform/strcat.h"
69 #include "tensorflow/core/platform/stringpiece.h"
70 #include "tensorflow/core/platform/thread_annotations.h"
71 #include "tensorflow/core/platform/types.h"
72 #include "tensorflow/core/public/session.h"
73 #include "tensorflow/core/public/version.h"
74
// The implementation below lives at the top level instead of inside the
// tensorflow namespace because it defines 'extern "C"' functions.
77 using tensorflow::AllocationDescription;
78 using tensorflow::DataType;
79 using tensorflow::ExtendSessionGraphHelper;
80 using tensorflow::Graph;
81 using tensorflow::GraphDef;
82 using tensorflow::mutex_lock;
83 using tensorflow::NameRangeMap;
84 using tensorflow::NameRangesForNode;
85 using tensorflow::NewSession;
86 using tensorflow::Node;
87 using tensorflow::NodeBuilder;
88 using tensorflow::NodeDef;
89 using tensorflow::OpDef;
90 using tensorflow::OpRegistry;
91 using tensorflow::OutputTensor;
92 using tensorflow::PartialTensorShape;
93 using tensorflow::RunMetadata;
94 using tensorflow::RunOptions;
95 using tensorflow::Session;
96 using tensorflow::Status;
97 using tensorflow::string;
98 using tensorflow::Tensor;
99 using tensorflow::TensorBuffer;
100 using tensorflow::TensorId;
101 using tensorflow::TensorShape;
102 using tensorflow::TensorShapeProto;
103 using tensorflow::VersionDef;
104 using tensorflow::errors::FailedPrecondition;
105 using tensorflow::errors::InvalidArgument;
106 using tensorflow::gtl::ArraySlice;
107 using tensorflow::strings::StrCat;
108
109 extern "C" {
110
111 // --------------------------------------------------------------------------
TF_Version()112 const char* TF_Version() { return TF_VERSION_STRING; }
113
114 // --------------------------------------------------------------------------
115
116 // --------------------------------------------------------------------------
TF_NewSessionOptions()117 TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
TF_DeleteSessionOptions(TF_SessionOptions * opt)118 void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
119
TF_SetTarget(TF_SessionOptions * options,const char * target)120 void TF_SetTarget(TF_SessionOptions* options, const char* target) {
121 options->options.target = target;
122 }
123
TF_SetConfig(TF_SessionOptions * options,const void * proto,size_t proto_len,TF_Status * status)124 void TF_SetConfig(TF_SessionOptions* options, const void* proto,
125 size_t proto_len, TF_Status* status) {
126 if (!options->options.config.ParseFromArray(proto, proto_len)) {
127 status->status = InvalidArgument("Unparseable ConfigProto");
128 }
129 }
130 // --------------------------------------------------------------------------
TF_NewBuffer()131 TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; }
132
TF_NewBufferFromString(const void * proto,size_t proto_len)133 TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) {
134 void* copy = tensorflow::port::Malloc(proto_len);
135 memcpy(copy, proto, proto_len);
136
137 TF_Buffer* buf = new TF_Buffer;
138 buf->data = copy;
139 buf->length = proto_len;
140 buf->data_deallocator = [](void* data, size_t length) {
141 tensorflow::port::Free(data);
142 };
143 return buf;
144 }
145
TF_DeleteBuffer(TF_Buffer * buffer)146 void TF_DeleteBuffer(TF_Buffer* buffer) {
147 if (buffer == nullptr) return;
148 if (buffer->data_deallocator != nullptr) {
149 (*buffer->data_deallocator)(const_cast<void*>(buffer->data),
150 buffer->length);
151 }
152 delete buffer;
153 }
154
TF_GetBuffer(TF_Buffer * buffer)155 TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }
156
157 // --------------------------------------------------------------------------
158
TF_NewDeprecatedSession(const TF_SessionOptions * opt,TF_Status * status)159 TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
160 TF_Status* status) {
161 Session* session;
162 status->status = NewSession(opt->options, &session);
163 if (status->status.ok()) {
164 return new TF_DeprecatedSession({session});
165 } else {
166 DCHECK_EQ(nullptr, session);
167 return nullptr;
168 }
169 }
170
TF_CloseDeprecatedSession(TF_DeprecatedSession * s,TF_Status * status)171 void TF_CloseDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
172 status->status = s->session->Close();
173 }
174
TF_DeleteDeprecatedSession(TF_DeprecatedSession * s,TF_Status * status)175 void TF_DeleteDeprecatedSession(TF_DeprecatedSession* s, TF_Status* status) {
176 status->status = Status::OK();
177 if (s == nullptr) return;
178 delete s->session;
179 delete s;
180 }
181
TF_ExtendGraph(TF_DeprecatedSession * s,const void * proto,size_t proto_len,TF_Status * status)182 void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
183 size_t proto_len, TF_Status* status) {
184 GraphDef g;
185 if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
186 status->status = InvalidArgument("Invalid GraphDef");
187 return;
188 }
189 status->status = s->session->Extend(g);
190 }
191
192 } // end extern "C"
193
194 // Reset helper for converting character arrays to string vectors.
TF_Reset_Helper(const TF_SessionOptions * opt,const char ** containers,int ncontainers,TF_Status * status)195 static void TF_Reset_Helper(const TF_SessionOptions* opt,
196 const char** containers, int ncontainers,
197 TF_Status* status) {
198 std::vector<string> container_names(ncontainers);
199 for (int i = 0; i < ncontainers; ++i) {
200 container_names[i] = containers[i];
201 }
202
203 status->status = Reset(opt->options, container_names);
204 }
205
206 extern "C" {
207
TF_Reset(const TF_SessionOptions * opt,const char ** containers,int ncontainers,TF_Status * status)208 void TF_Reset(const TF_SessionOptions* opt, const char** containers,
209 int ncontainers, TF_Status* status) {
210 TF_Reset_Helper(opt, containers, ncontainers, status);
211 }
212
213 } // end extern "C"
214
215 namespace tensorflow {
216
MessageToBuffer(const tensorflow::protobuf::MessageLite & in,TF_Buffer * out)217 Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
218 TF_Buffer* out) {
219 if (out->data != nullptr) {
220 return InvalidArgument("Passing non-empty TF_Buffer is invalid.");
221 }
222 const size_t proto_size = in.ByteSizeLong();
223 void* buf = port::Malloc(proto_size);
224 if (buf == nullptr) {
225 return tensorflow::errors::ResourceExhausted(
226 "Failed to allocate memory to serialize message of type '",
227 in.GetTypeName(), "' and size ", proto_size);
228 }
229 if (!in.SerializeWithCachedSizesToArray(static_cast<uint8*>(buf))) {
230 port::Free(buf);
231 return InvalidArgument("Unable to serialize ", in.GetTypeName(),
232 " protocol buffer, perhaps the serialized size (",
233 proto_size, " bytes) is too large?");
234 }
235 out->data = buf;
236 out->length = proto_size;
237 out->data_deallocator = [](void* data, size_t length) { port::Free(data); };
238 return Status::OK();
239 }
240
RecordMutation(TF_Graph * graph,const TF_Operation & op,const char * mutation_type)241 void RecordMutation(TF_Graph* graph, const TF_Operation& op,
242 const char* mutation_type) {
243 // If any session has already run this node_id, mark this session as
244 // unrunnable.
245 for (auto it : graph->sessions) {
246 mutex_lock session_lock(it.first->mu);
247 if (it.first->last_num_graph_nodes > op.node.id()) {
248 it.second = strings::StrCat(
249 "Operation '", op.node.DebugString(), "' was changed by ",
250 mutation_type,
251 " after it was run by a session. This mutation will have no effect, "
252 "and will trigger an error in the future. Either don't modify "
253 "nodes after running them or create a new session.");
254 }
255 }
256 }
257
258 namespace {
259
260 // Helper method that creates a shape handle for a shape described by dims.
ShapeHandleFromDims(tensorflow::shape_inference::InferenceContext * ic,int num_dims,const int64_t * dims)261 tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims(
262 tensorflow::shape_inference::InferenceContext* ic, int num_dims,
263 const int64_t* dims) {
264 if (num_dims != -1) {
265 std::vector<tensorflow::shape_inference::DimensionHandle> dim_vec;
266 dim_vec.reserve(num_dims);
267 for (int i = 0; i < num_dims; ++i) {
268 dim_vec.push_back(ic->MakeDim(dims[i]));
269 }
270 return ic->MakeShape(dim_vec);
271 } else {
272 return ic->UnknownShape();
273 }
274 }
275
276 } // namespace
277
TF_GraphSetOutputHandleShapesAndTypes(TF_Graph * graph,TF_Output output,int num_shapes_and_types,const int64_t ** shapes,const int * ranks,const TF_DataType * types,TF_Status * status)278 void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
279 int num_shapes_and_types,
280 const int64_t** shapes,
281 const int* ranks,
282 const TF_DataType* types,
283 TF_Status* status) {
284 Node* node = &output.oper->node;
285
286 mutex_lock l(graph->mu);
287 tensorflow::shape_inference::InferenceContext* ic =
288 graph->refiner.GetContext(node);
289 if (ic == nullptr) {
290 status->status =
291 InvalidArgument("Node ", node->name(), " was not found in the graph");
292 return;
293 }
294
295 auto shape_and_type_vec =
296 std::vector<tensorflow::shape_inference::ShapeAndType>(
297 num_shapes_and_types);
298 for (int i = 0; i < num_shapes_and_types; ++i) {
299 tensorflow::shape_inference::ShapeHandle shape_handle =
300 ShapeHandleFromDims(ic, ranks[i], shapes[i]);
301 shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType(
302 shape_handle, static_cast<DataType>(types[i]));
303 }
304
305 ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec);
306 }
307
308 // Helpers for loading a TensorFlow plugin (a .so file).
309 Status LoadDynamicLibrary(const char* library_filename, void** result,
310 const void** buf, size_t* len);
311
312 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
313 // directly, instead of requiring us to serialize to a GraphDef and
314 // call Session::Extend().
ExtendSessionGraphHelper(TF_Session * session,TF_Status * status)315 bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
316 if (session->graph != nullptr) {
317 // Take the graph lock before the session lock to avoid deadlock. This is
318 // safe since session->graph does not change.
319 session->graph->mu.lock();
320 mutex_lock session_lock(session->mu);
321 const Graph& graph = session->graph->graph;
322
323 const string& mutation_warning = session->graph->sessions[session];
324 if (!mutation_warning.empty()) {
325 // TODO(b/74949947): turn this back into an error status
326 LOG(WARNING) << mutation_warning;
327 session->graph->sessions[session].clear();
328 }
329
330 const auto num_nodes = graph.num_node_ids();
331 if (session->last_num_graph_nodes < num_nodes) {
332 // TODO(nolivia): check this on a subset of the graph instead of all of
333 // it.
334 status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
335 if (!status->status.ok()) {
336 session->graph->mu.unlock();
337 return false;
338 }
339
340 GraphDef graph_def;
341 *graph_def.mutable_versions() = graph.versions();
342 // Fill graph_def with nodes with ids in the range
343 // [session->last_num_graph_nodes, num_nodes), that is the nodes
344 // added since the last TF_SessionRun() call.
345 for (auto id = session->last_num_graph_nodes; id < num_nodes; ++id) {
346 Node* const node = graph.FindNodeId(id);
347 if (node != nullptr && node->IsOp()) {
348 NodeDef* const node_def = graph_def.add_node();
349 *node_def = node->def();
350 }
351 }
352 *graph_def.mutable_library() = graph.flib_def().ToProto();
353 session->graph->mu.unlock();
354 status->status = session->session->Extend(std::move(graph_def));
355 if (!status->status.ok()) {
356 // Contract is we always delete input_values[i].
357 return false;
358 }
359 // Note: session->session is not modified if Extend() fails, so
360 // we only set last_num_graph_nodes if it succeeds.
361 session->last_num_graph_nodes = num_nodes;
362 } else {
363 session->graph->mu.unlock();
364 }
365 }
366 return true;
367 }
368
369 } // namespace tensorflow
370
TF_Run_Setup(int noutputs,TF_Tensor ** c_outputs,TF_Status * status)371 static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs,
372 TF_Status* status) {
373 status->status = Status::OK();
374 for (int i = 0; i < noutputs; ++i) {
375 c_outputs[i] = nullptr;
376 }
377 }
378
TF_Run_Inputs(TF_Tensor * const * c_inputs,std::vector<std::pair<string,Tensor>> * input_pairs,TF_Status * status)379 static bool TF_Run_Inputs(TF_Tensor* const* c_inputs,
380 std::vector<std::pair<string, Tensor>>* input_pairs,
381 TF_Status* status) {
382 const int ninputs = input_pairs->size();
383 for (int i = 0; i < ninputs; ++i) {
384 status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second);
385 if (!status->status.ok()) return false;
386 }
387 return true;
388 }
389
390 // Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
391 // result in a zero-sized tensor.
EmptyTensor(TF_DataType dtype,const tensorflow::TensorShape & shape)392 static TF_Tensor* EmptyTensor(TF_DataType dtype,
393 const tensorflow::TensorShape& shape) {
394 static char empty;
395 tensorflow::int64 nelems = 1;
396 std::vector<tensorflow::int64> dims;
397 for (int i = 0; i < shape.dims(); ++i) {
398 dims.push_back(shape.dim_size(i));
399 nelems *= shape.dim_size(i);
400 }
401 CHECK_EQ(nelems, 0);
402 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
403 "64-bit int types should match in size");
404 return TF_NewTensor(
405 dtype, reinterpret_cast<const int64_t*>(dims.data()), shape.dims(),
406 reinterpret_cast<void*>(&empty), 0, [](void*, size_t, void*) {}, nullptr);
407 }
408
TF_Run_Helper(Session * session,const char * handle,const TF_Buffer * run_options,const std::vector<std::pair<string,Tensor>> & input_pairs,const std::vector<string> & output_tensor_names,TF_Tensor ** c_outputs,const std::vector<string> & target_oper_names,TF_Buffer * run_metadata,TF_Status * status)409 static void TF_Run_Helper(
410 Session* session, const char* handle, const TF_Buffer* run_options,
411 // Input tensors
412 const std::vector<std::pair<string, Tensor>>& input_pairs,
413 // Output tensors
414 const std::vector<string>& output_tensor_names, TF_Tensor** c_outputs,
415 // Target nodes
416 const std::vector<string>& target_oper_names, TF_Buffer* run_metadata,
417 TF_Status* status) {
418 const int noutputs = output_tensor_names.size();
419 std::vector<Tensor> outputs(noutputs);
420 Status result;
421
422 if (handle == nullptr) {
423 RunOptions run_options_proto;
424 if (run_options != nullptr && !run_options_proto.ParseFromArray(
425 run_options->data, run_options->length)) {
426 status->status = InvalidArgument("Unparseable RunOptions proto");
427 return;
428 }
429 if (run_metadata != nullptr && run_metadata->data != nullptr) {
430 status->status =
431 InvalidArgument("Passing non-empty run_metadata is invalid.");
432 return;
433 }
434
435 RunMetadata run_metadata_proto;
436 result = session->Run(run_options_proto, input_pairs, output_tensor_names,
437 target_oper_names, &outputs, &run_metadata_proto);
438
439 // Serialize back to upstream client, who now owns the new buffer
440 if (run_metadata != nullptr) {
441 status->status = MessageToBuffer(run_metadata_proto, run_metadata);
442 if (!status->status.ok()) return;
443 }
444 } else {
445 // NOTE(zongheng): PRun does not support RunOptions yet.
446 result = session->PRun(handle, input_pairs, output_tensor_names, &outputs);
447 }
448 if (!result.ok()) {
449 status->status = result;
450 return;
451 }
452
453 // Store results in c_outputs[]
454 for (int i = 0; i < noutputs; ++i) {
455 const Tensor& src = outputs[i];
456 if (!src.IsInitialized() || src.NumElements() == 0) {
457 c_outputs[i] =
458 EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
459 continue;
460 }
461 c_outputs[i] = TF_TensorFromTensor(src, &status->status);
462 if (!status->status.ok()) return;
463 }
464 }
465
466 extern "C" {
467
TF_Run(TF_DeprecatedSession * s,const TF_Buffer * run_options,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Buffer * run_metadata,TF_Status * status)468 void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options,
469 // Input tensors
470 const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
471 // Output tensors
472 const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
473 // Target nodes
474 const char** c_target_oper_names, int ntargets,
475 TF_Buffer* run_metadata, TF_Status* status) {
476 TF_Run_Setup(noutputs, c_outputs, status);
477 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
478 if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
479 for (int i = 0; i < ninputs; ++i) {
480 input_pairs[i].first = c_input_names[i];
481 }
482 std::vector<string> output_names(noutputs);
483 for (int i = 0; i < noutputs; ++i) {
484 output_names[i] = c_output_names[i];
485 }
486 std::vector<string> target_oper_names(ntargets);
487 for (int i = 0; i < ntargets; ++i) {
488 target_oper_names[i] = c_target_oper_names[i];
489 }
490 TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
491 c_outputs, target_oper_names, run_metadata, status);
492 }
493
TF_PRunSetup(TF_DeprecatedSession * s,const char ** c_input_names,int ninputs,const char ** c_output_names,int noutputs,const char ** c_target_oper_names,int ntargets,const char ** handle,TF_Status * status)494 void TF_PRunSetup(TF_DeprecatedSession* s,
495 // Input names
496 const char** c_input_names, int ninputs,
497 // Output names
498 const char** c_output_names, int noutputs,
499 // Target nodes
500 const char** c_target_oper_names, int ntargets,
501 const char** handle, TF_Status* status) {
502 *handle = nullptr;
503
504 std::vector<string> input_names(ninputs);
505 std::vector<string> output_names(noutputs);
506 std::vector<string> target_oper_names(ntargets);
507 for (int i = 0; i < ninputs; ++i) {
508 input_names[i] = c_input_names[i];
509 }
510 for (int i = 0; i < noutputs; ++i) {
511 output_names[i] = c_output_names[i];
512 }
513 for (int i = 0; i < ntargets; ++i) {
514 target_oper_names[i] = c_target_oper_names[i];
515 }
516 string new_handle;
517 status->status = s->session->PRunSetup(input_names, output_names,
518 target_oper_names, &new_handle);
519 if (status->status.ok()) {
520 char* buf = new char[new_handle.size() + 1];
521 memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
522 *handle = buf;
523 }
524 }
525
TF_PRun(TF_DeprecatedSession * s,const char * handle,const char ** c_input_names,TF_Tensor ** c_inputs,int ninputs,const char ** c_output_names,TF_Tensor ** c_outputs,int noutputs,const char ** c_target_oper_names,int ntargets,TF_Status * status)526 void TF_PRun(TF_DeprecatedSession* s, const char* handle,
527 // Input tensors
528 const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
529 // Output tensors
530 const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
531 // Target nodes
532 const char** c_target_oper_names, int ntargets,
533 TF_Status* status) {
534 TF_Run_Setup(noutputs, c_outputs, status);
535 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
536 if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return;
537 for (int i = 0; i < ninputs; ++i) {
538 input_pairs[i].first = c_input_names[i];
539 }
540
541 std::vector<string> output_names(noutputs);
542 for (int i = 0; i < noutputs; ++i) {
543 output_names[i] = c_output_names[i];
544 }
545 std::vector<string> target_oper_names(ntargets);
546 for (int i = 0; i < ntargets; ++i) {
547 target_oper_names[i] = c_target_oper_names[i];
548 }
549 TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
550 c_outputs, target_oper_names, nullptr, status);
551 }
552
TF_LoadLibrary(const char * library_filename,TF_Status * status)553 TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
554 TF_Library* lib_handle = new TF_Library;
555 status->status = tensorflow::LoadDynamicLibrary(
556 library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
557 &lib_handle->op_list.length);
558 if (!status->status.ok()) {
559 delete lib_handle;
560 return nullptr;
561 }
562 return lib_handle;
563 }
564
TF_GetOpList(TF_Library * lib_handle)565 TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
566
TF_DeleteLibraryHandle(TF_Library * lib_handle)567 void TF_DeleteLibraryHandle(TF_Library* lib_handle) {
568 if (lib_handle == nullptr) return;
569 tensorflow::port::Free(const_cast<void*>(lib_handle->op_list.data));
570 delete lib_handle;
571 }
572
TF_GetAllOpList()573 TF_Buffer* TF_GetAllOpList() {
574 std::vector<tensorflow::OpDef> op_defs;
575 tensorflow::OpRegistry::Global()->GetRegisteredOps(&op_defs);
576 tensorflow::OpList op_list;
577 for (const auto& op : op_defs) {
578 *(op_list.add_op()) = op;
579 }
580 TF_Buffer* ret = TF_NewBuffer();
581 TF_CHECK_OK(MessageToBuffer(op_list, ret));
582 return ret;
583 }
584
585 // --------------------------------------------------------------------------
586 // ListDevices & SessionListDevices API
587
TF_DeleteDeviceList(TF_DeviceList * list)588 void TF_DeleteDeviceList(TF_DeviceList* list) { delete list; }
589
TF_SessionListDevices(TF_Session * session,TF_Status * status)590 TF_DeviceList* TF_SessionListDevices(TF_Session* session, TF_Status* status) {
591 TF_DeviceList* response = new TF_DeviceList;
592 if (session && session->session)
593 status->status = session->session->ListDevices(&response->response);
594 return response;
595 }
596
TF_DeprecatedSessionListDevices(TF_DeprecatedSession * session,TF_Status * status)597 TF_DeviceList* TF_DeprecatedSessionListDevices(TF_DeprecatedSession* session,
598 TF_Status* status) {
599 TF_DeviceList* response = new TF_DeviceList;
600 if (session && session->session)
601 status->status = session->session->ListDevices(&response->response);
602 return response;
603 }
604
TF_DeviceListCount(const TF_DeviceList * list)605 int TF_DeviceListCount(const TF_DeviceList* list) {
606 return list->response.size();
607 }
608
609 #define TF_DEVICELIST_METHOD(return_type, method_name, accessor, err_val) \
610 return_type method_name(const TF_DeviceList* list, const int index, \
611 TF_Status* status) { \
612 if (list == nullptr) { \
613 status->status = InvalidArgument("list is null!"); \
614 return err_val; \
615 } \
616 if (index < 0 || index >= list->response.size()) { \
617 status->status = InvalidArgument("index out of bounds"); \
618 return err_val; \
619 } \
620 status->status = Status::OK(); \
621 return list->response[index].accessor; \
622 }
623
624 TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
625 TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
626 nullptr);
627 TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
628 TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);
629
630 #undef TF_DEVICELIST_METHOD
631
632 } // end extern "C"
633
634 // --------------------------------------------------------------------------
635 // New Graph and Session API
636
637 // Helper functions -----------------------------------------------------------
638
639 namespace {
640
ToOperation(Node * node)641 TF_Operation* ToOperation(Node* node) {
642 return static_cast<TF_Operation*>(static_cast<void*>(node));
643 }
644
OutputName(const TF_Output & output)645 string OutputName(const TF_Output& output) {
646 return StrCat(output.oper->node.name(), ":", output.index);
647 }
648
GetAttrValue(TF_Operation * oper,const char * attr_name,TF_Status * status)649 const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
650 const char* attr_name,
651 TF_Status* status) {
652 const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
653 if (attr == nullptr) {
654 status->status = InvalidArgument("Operation '", oper->node.name(),
655 "' has no attr named '", attr_name, "'.");
656 }
657 return attr;
658 }
659
ToTensorId(const TF_Output & output)660 TensorId ToTensorId(const TF_Output& output) {
661 return TensorId(output.oper->node.name(), output.index);
662 }
663
664 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
OutputsFromTFOutputs(TF_Output * tf_outputs,int n)665 std::vector<tensorflow::Output> OutputsFromTFOutputs(TF_Output* tf_outputs,
666 int n) {
667 std::vector<tensorflow::Output> outputs(n);
668 for (int i = 0; i < n; ++i) {
669 outputs[i] =
670 tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index);
671 }
672 return outputs;
673 }
674
TFOutputsFromOutputs(const std::vector<tensorflow::Output> & outputs,TF_Output * tf_outputs)675 void TFOutputsFromOutputs(const std::vector<tensorflow::Output>& outputs,
676 TF_Output* tf_outputs) {
677 for (int i = 0; i < outputs.size(); i++) {
678 tf_outputs[i].oper = ToOperation(outputs[i].node());
679 tf_outputs[i].index = outputs[i].index();
680 }
681 }
682 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
683
684 } // namespace
685
686 // Shape functions -----------------------------------------------------------
687
TF_GraphSetTensorShape(TF_Graph * graph,TF_Output output,const int64_t * dims,const int num_dims,TF_Status * status)688 void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output,
689 const int64_t* dims, const int num_dims,
690 TF_Status* status) {
691 Node* node = &output.oper->node;
692
693 mutex_lock l(graph->mu);
694 tensorflow::shape_inference::InferenceContext* ic =
695 graph->refiner.GetContext(node);
696 if (ic == nullptr) {
697 status->status =
698 InvalidArgument("Node ", node->name(), " was not found in the graph");
699 return;
700 }
701 tensorflow::shape_inference::ShapeHandle new_shape =
702 tensorflow::ShapeHandleFromDims(ic, num_dims, dims);
703 status->status = graph->refiner.SetShape(node, output.index, new_shape);
704 }
705
TF_GraphGetTensorNumDims(TF_Graph * graph,TF_Output output,TF_Status * status)706 int TF_GraphGetTensorNumDims(TF_Graph* graph, TF_Output output,
707 TF_Status* status) {
708 Node* node = &output.oper->node;
709
710 mutex_lock l(graph->mu);
711 tensorflow::shape_inference::InferenceContext* ic =
712 graph->refiner.GetContext(node);
713 if (ic == nullptr) {
714 status->status =
715 InvalidArgument("Node ", node->name(), " was not found in the graph");
716 return -1;
717 }
718
719 tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
720
721 // Unknown rank means the number of dimensions is -1.
722 if (!ic->RankKnown(shape)) {
723 return -1;
724 }
725
726 return ic->Rank(shape);
727 }
728
TF_GraphGetTensorShape(TF_Graph * graph,TF_Output output,int64_t * dims,int num_dims,TF_Status * status)729 void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,
730 int num_dims, TF_Status* status) {
731 Node* node = &output.oper->node;
732
733 mutex_lock l(graph->mu);
734 tensorflow::shape_inference::InferenceContext* ic =
735 graph->refiner.GetContext(node);
736 if (ic == nullptr) {
737 status->status =
738 InvalidArgument("Node ", node->name(), " was not found in the graph");
739 return;
740 }
741
742 tensorflow::shape_inference::ShapeHandle shape = ic->output(output.index);
743
744 int rank = -1;
745 if (ic->RankKnown(shape)) {
746 rank = ic->Rank(shape);
747 }
748
749 if (num_dims != rank) {
750 status->status = InvalidArgument("Expected rank is ", num_dims,
751 " but actual rank is ", rank);
752 return;
753 }
754
755 if (num_dims == 0) {
756 // Output shape is a scalar.
757 return;
758 }
759
760 // Rank is greater than 0, so fill in the values, if known, and
761 // -1 for unknown values.
762 for (int i = 0; i < num_dims; ++i) {
763 auto dim = ic->Dim(shape, i);
764 tensorflow::int64 value = -1;
765 if (ic->ValueKnown(dim)) {
766 value = ic->Value(dim);
767 }
768 dims[i] = value;
769 }
770 }
771
772 // TF_OperationDescription functions ------------------------------------------
773
774 extern "C" {
775
TF_NewOperationLocked(TF_Graph * graph,const char * op_type,const char * oper_name)776 static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
777 const char* op_type,
778 const char* oper_name)
779 TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
780 return new TF_OperationDescription(graph, op_type, oper_name);
781 }
782
TF_NewOperation(TF_Graph * graph,const char * op_type,const char * oper_name)783 TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
784 const char* oper_name) {
785 mutex_lock l(graph->mu);
786 return TF_NewOperationLocked(graph, op_type, oper_name);
787 }
788
TF_SetDevice(TF_OperationDescription * desc,const char * device)789 void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
790 desc->node_builder.Device(device);
791 }
792
TF_AddInput(TF_OperationDescription * desc,TF_Output input)793 void TF_AddInput(TF_OperationDescription* desc, TF_Output input) {
794 desc->node_builder.Input(&input.oper->node, input.index);
795 }
796
TF_AddInputList(TF_OperationDescription * desc,const TF_Output * inputs,int num_inputs)797 void TF_AddInputList(TF_OperationDescription* desc, const TF_Output* inputs,
798 int num_inputs) {
799 std::vector<NodeBuilder::NodeOut> input_list;
800 input_list.reserve(num_inputs);
801 for (int i = 0; i < num_inputs; ++i) {
802 input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
803 }
804 desc->node_builder.Input(input_list);
805 }
806
TF_AddControlInput(TF_OperationDescription * desc,TF_Operation * input)807 void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
808 desc->node_builder.ControlInput(&input->node);
809 }
810
TF_ColocateWith(TF_OperationDescription * desc,TF_Operation * op)811 void TF_ColocateWith(TF_OperationDescription* desc, TF_Operation* op) {
812 desc->colocation_constraints.emplace(
813 StrCat(tensorflow::kColocationGroupPrefix, op->node.name()));
814 }
815
TF_SetAttrString(TF_OperationDescription * desc,const char * attr_name,const void * value,size_t length)816 void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
817 const void* value, size_t length) {
818 tensorflow::StringPiece s(static_cast<const char*>(value), length);
819 desc->node_builder.Attr(attr_name, s);
820 }
821
TF_SetAttrStringList(TF_OperationDescription * desc,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)822 void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
823 const void* const* values, const size_t* lengths,
824 int num_values) {
825 if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
826 desc->colocation_constraints.clear();
827 for (int i = 0; i < num_values; ++i) {
828 desc->colocation_constraints.emplace(static_cast<const char*>(values[i]),
829 lengths[i]);
830 }
831 } else {
832 std::vector<tensorflow::StringPiece> v;
833 v.reserve(num_values);
834 for (int i = 0; i < num_values; ++i) {
835 v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
836 }
837 desc->node_builder.Attr(attr_name, v);
838 }
839 }
840
TF_SetAttrInt(TF_OperationDescription * desc,const char * attr_name,int64_t value)841 void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
842 int64_t value) {
843 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
844 "64-bit int types should match in size");
845 desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
846 }
847
TF_SetAttrIntList(TF_OperationDescription * desc,const char * attr_name,const int64_t * values,int num_values)848 void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
849 const int64_t* values, int num_values) {
850 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
851 "64-bit int types should match in size");
852 desc->node_builder.Attr(
853 attr_name,
854 ArraySlice<const tensorflow::int64>(
855 reinterpret_cast<const tensorflow::int64*>(values), num_values));
856 }
857
TF_SetAttrFloat(TF_OperationDescription * desc,const char * attr_name,float value)858 void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
859 float value) {
860 desc->node_builder.Attr(attr_name, value);
861 }
862
TF_SetAttrFloatList(TF_OperationDescription * desc,const char * attr_name,const float * values,int num_values)863 void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
864 const float* values, int num_values) {
865 desc->node_builder.Attr(attr_name,
866 ArraySlice<const float>(values, num_values));
867 }
868
TF_SetAttrBool(TF_OperationDescription * desc,const char * attr_name,unsigned char value)869 void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
870 unsigned char value) {
871 desc->node_builder.Attr(attr_name, static_cast<bool>(value));
872 }
873
TF_SetAttrBoolList(TF_OperationDescription * desc,const char * attr_name,const unsigned char * values,int num_values)874 void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
875 const unsigned char* values, int num_values) {
876 std::unique_ptr<bool[]> b(new bool[num_values]);
877 for (int i = 0; i < num_values; ++i) {
878 b[i] = values[i];
879 }
880 desc->node_builder.Attr(attr_name,
881 ArraySlice<const bool>(b.get(), num_values));
882 }
883
TF_SetAttrType(TF_OperationDescription * desc,const char * attr_name,TF_DataType value)884 void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
885 TF_DataType value) {
886 desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
887 }
888
TF_SetAttrTypeList(TF_OperationDescription * desc,const char * attr_name,const TF_DataType * values,int num_values)889 void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
890 const TF_DataType* values, int num_values) {
891 desc->node_builder.Attr(
892 attr_name, ArraySlice<const DataType>(
893 reinterpret_cast<const DataType*>(values), num_values));
894 }
895
TF_SetAttrPlaceholder(TF_OperationDescription * desc,const char * attr_name,const char * placeholder)896 void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name,
897 const char* placeholder) {
898 tensorflow::AttrValue attr_value;
899 attr_value.set_placeholder(placeholder);
900 desc->node_builder.Attr(attr_name, attr_value);
901 }
902
TF_SetAttrFuncName(TF_OperationDescription * desc,const char * attr_name,const char * value,size_t length)903 void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
904 const char* value, size_t length) {
905 tensorflow::NameAttrList func_name;
906 func_name.set_name(string(value, value + length));
907 desc->node_builder.Attr(attr_name, func_name);
908 }
909
TF_SetAttrShape(TF_OperationDescription * desc,const char * attr_name,const int64_t * dims,int num_dims)910 void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
911 const int64_t* dims, int num_dims) {
912 PartialTensorShape shape;
913 if (num_dims >= 0) {
914 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
915 "64-bit int types should match in size");
916 shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
917 reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
918 }
919 desc->node_builder.Attr(attr_name, shape);
920 }
921
TF_SetAttrShapeList(TF_OperationDescription * desc,const char * attr_name,const int64_t * const * dims,const int * num_dims,int num_shapes)922 void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
923 const int64_t* const* dims, const int* num_dims,
924 int num_shapes) {
925 std::vector<PartialTensorShape> shapes;
926 shapes.reserve(num_shapes);
927 for (int i = 0; i < num_shapes; ++i) {
928 if (num_dims[i] < 0) {
929 shapes.emplace_back();
930 } else {
931 static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
932 "64-bit int types should match in size");
933 shapes.emplace_back(ArraySlice<tensorflow::int64>(
934 reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
935 }
936 }
937 desc->node_builder.Attr(attr_name, shapes);
938 }
939
TF_SetAttrTensorShapeProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)940 void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
941 const char* attr_name, const void* proto,
942 size_t proto_len, TF_Status* status) {
943 // shape.ParseFromArray takes an int as length, this function takes size_t,
944 // make sure there is no information loss.
945 if (proto_len > std::numeric_limits<int>::max()) {
946 status->status = InvalidArgument(
947 "proto_len (", proto_len,
948 " bytes) is too large to be parsed by the protocol buffer library");
949 return;
950 }
951 TensorShapeProto shape;
952 if (shape.ParseFromArray(proto, static_cast<int>(proto_len))) {
953 desc->node_builder.Attr(attr_name, shape);
954 status->status = Status::OK();
955 } else {
956 status->status = InvalidArgument("Unparseable TensorShapeProto");
957 }
958 }
959
TF_SetAttrTensorShapeProtoList(TF_OperationDescription * desc,const char * attr_name,const void * const * protos,const size_t * proto_lens,int num_shapes,TF_Status * status)960 void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
961 const char* attr_name,
962 const void* const* protos,
963 const size_t* proto_lens, int num_shapes,
964 TF_Status* status) {
965 std::vector<TensorShapeProto> shapes;
966 shapes.resize(num_shapes);
967 for (int i = 0; i < num_shapes; ++i) {
968 if (proto_lens[i] > std::numeric_limits<int>::max()) {
969 status->status = InvalidArgument(
970 "length of element ", i, " in the list (", proto_lens[i],
971 " bytes) is too large to be parsed by the protocol buffer library");
972 return;
973 }
974 if (!shapes[i].ParseFromArray(protos[i], static_cast<int>(proto_lens[i]))) {
975 status->status =
976 InvalidArgument("Unparseable TensorShapeProto at index ", i);
977 return;
978 }
979 }
980 desc->node_builder.Attr(attr_name, shapes);
981 status->status = Status::OK();
982 }
983
TF_SetAttrTensor(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * value,TF_Status * status)984 void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
985 TF_Tensor* value, TF_Status* status) {
986 Tensor t;
987 status->status = TF_TensorToTensor(value, &t);
988 if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
989 }
990
TF_SetAttrTensorList(TF_OperationDescription * desc,const char * attr_name,TF_Tensor * const * values,int num_values,TF_Status * status)991 void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
992 TF_Tensor* const* values, int num_values,
993 TF_Status* status) {
994 status->status = Status::OK();
995 std::vector<Tensor> t;
996 t.reserve(num_values);
997
998 for (int i = 0; i < num_values && status->status.ok(); ++i) {
999 Tensor v;
1000 status->status = TF_TensorToTensor(values[i], &v);
1001 t.emplace_back(v);
1002 }
1003
1004 if (status->status.ok()) desc->node_builder.Attr(attr_name, t);
1005 }
1006
TF_SetAttrValueProto(TF_OperationDescription * desc,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)1007 void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
1008 const void* proto, size_t proto_len,
1009 TF_Status* status) {
1010 tensorflow::AttrValue attr_value;
1011 if (!attr_value.ParseFromArray(proto, proto_len)) {
1012 status->status = InvalidArgument("Unparseable AttrValue proto");
1013 return;
1014 }
1015
1016 if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
1017 if (attr_value.value_case() != tensorflow::AttrValue::kList &&
1018 attr_value.value_case() != tensorflow::AttrValue::VALUE_NOT_SET) {
1019 status->status =
1020 InvalidArgument("Expected \"list\" field for \"",
1021 tensorflow::kColocationAttrName, "\" attribute");
1022 return;
1023 }
1024 desc->colocation_constraints.clear();
1025 for (const string& location : attr_value.list().s()) {
1026 desc->colocation_constraints.insert(location);
1027 }
1028 } else {
1029 desc->node_builder.Attr(attr_name, std::move(attr_value));
1030 }
1031
1032 status->status = Status::OK();
1033 }
1034
TF_FinishOperationLocked(TF_OperationDescription * desc,TF_Status * status)1035 static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
1036 TF_Status* status)
1037 TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
1038 Node* ret = nullptr;
1039
1040 if (desc->graph->name_map.count(desc->node_builder.node_name())) {
1041 status->status = InvalidArgument("Duplicate node name in graph: '",
1042 desc->node_builder.node_name(), "'");
1043 } else {
1044 if (!desc->colocation_constraints.empty()) {
1045 desc->node_builder.Attr(
1046 tensorflow::kColocationAttrName,
1047 std::vector<string>(desc->colocation_constraints.begin(),
1048 desc->colocation_constraints.end()));
1049 }
1050 status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret,
1051 /*consume=*/true);
1052
1053 if (status->status.ok()) {
1054 // Run shape inference function for newly added node.
1055 status->status = desc->graph->refiner.AddNode(ret);
1056 }
1057 if (status->status.ok()) {
1058 // Add the node to the name-to-node mapping.
1059 desc->graph->name_map[ret->name()] = ret;
1060 } else if (ret != nullptr) {
1061 desc->graph->graph.RemoveNode(ret);
1062 ret = nullptr;
1063 }
1064 }
1065
1066 delete desc;
1067
1068 return ToOperation(ret);
1069 }
1070
TF_FinishOperation(TF_OperationDescription * desc,TF_Status * status)1071 TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
1072 TF_Status* status) {
1073 mutex_lock l(desc->graph->mu);
1074 return TF_FinishOperationLocked(desc, status);
1075 }
1076
1077 // TF_Operation functions
1078 // ----------------------------------------------------------
1079
TF_OperationName(TF_Operation * oper)1080 const char* TF_OperationName(TF_Operation* oper) {
1081 return oper->node.name().c_str();
1082 }
1083
TF_OperationOpType(TF_Operation * oper)1084 const char* TF_OperationOpType(TF_Operation* oper) {
1085 return oper->node.type_string().c_str();
1086 }
1087
TF_OperationDevice(TF_Operation * oper)1088 const char* TF_OperationDevice(TF_Operation* oper) {
1089 return oper->node.requested_device().c_str();
1090 }
1091
TF_OperationNumOutputs(TF_Operation * oper)1092 int TF_OperationNumOutputs(TF_Operation* oper) {
1093 return oper->node.num_outputs();
1094 }
1095
TF_OperationOutputType(TF_Output oper_out)1096 TF_DataType TF_OperationOutputType(TF_Output oper_out) {
1097 return static_cast<TF_DataType>(
1098 oper_out.oper->node.output_type(oper_out.index));
1099 }
1100
TF_OperationOutputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1101 int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
1102 TF_Status* status) {
1103 NameRangeMap name_ranges;
1104 status->status =
1105 NameRangesForNode(oper->node, oper->node.op_def(), nullptr, &name_ranges);
1106 if (!status->status.ok()) return -1;
1107 auto iter = name_ranges.find(arg_name);
1108 if (iter == name_ranges.end()) {
1109 status->status = InvalidArgument("Output arg '", arg_name, "' not found");
1110 return -1;
1111 }
1112 return iter->second.second - iter->second.first;
1113 }
1114
TF_OperationNumInputs(TF_Operation * oper)1115 int TF_OperationNumInputs(TF_Operation* oper) {
1116 return oper->node.num_inputs();
1117 }
1118
TF_OperationInputType(TF_Input oper_in)1119 TF_DataType TF_OperationInputType(TF_Input oper_in) {
1120 return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
1121 }
1122
TF_OperationInputListLength(TF_Operation * oper,const char * arg_name,TF_Status * status)1123 int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
1124 TF_Status* status) {
1125 NameRangeMap name_ranges;
1126 status->status =
1127 NameRangesForNode(oper->node, oper->node.op_def(), &name_ranges, nullptr);
1128 if (!status->status.ok()) return -1;
1129 auto iter = name_ranges.find(arg_name);
1130 if (iter == name_ranges.end()) {
1131 status->status = InvalidArgument("Input arg '", arg_name, "' not found");
1132 return -1;
1133 }
1134 return iter->second.second - iter->second.first;
1135 }
1136
TF_OperationInput(TF_Input oper_in)1137 TF_Output TF_OperationInput(TF_Input oper_in) {
1138 const tensorflow::Edge* edge;
1139 Status s = oper_in.oper->node.input_edge(oper_in.index, &edge);
1140 if (!s.ok()) {
1141 return {nullptr, -1};
1142 }
1143
1144 return {ToOperation(edge->src()), edge->src_output()};
1145 }
1146
TF_OperationAllInputs(TF_Operation * oper,TF_Output * inputs,int max_inputs)1147 void TF_OperationAllInputs(TF_Operation* oper, TF_Output* inputs,
1148 int max_inputs) {
1149 for (auto* edge : oper->node.in_edges()) {
1150 if (edge->dst_input() >= 0 && edge->dst_input() < max_inputs) {
1151 inputs[edge->dst_input()] = {ToOperation(edge->src()),
1152 edge->src_output()};
1153 }
1154 }
1155 }
1156
TF_OperationOutputNumConsumers(TF_Output oper_out)1157 int TF_OperationOutputNumConsumers(TF_Output oper_out) {
1158 int count = 0;
1159 for (const auto* edge : oper_out.oper->node.out_edges()) {
1160 if (edge->src_output() == oper_out.index) {
1161 ++count;
1162 }
1163 }
1164 return count;
1165 }
1166
TF_OperationOutputConsumers(TF_Output oper_out,TF_Input * consumers,int max_consumers)1167 int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers,
1168 int max_consumers) {
1169 int count = 0;
1170 for (const auto* edge : oper_out.oper->node.out_edges()) {
1171 if (edge->src_output() == oper_out.index) {
1172 if (count < max_consumers) {
1173 consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
1174 }
1175 ++count;
1176 }
1177 }
1178 return count;
1179 }
1180
TF_OperationNumControlInputs(TF_Operation * oper)1181 int TF_OperationNumControlInputs(TF_Operation* oper) {
1182 int count = 0;
1183 for (const auto* edge : oper->node.in_edges()) {
1184 if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1185 ++count;
1186 }
1187 }
1188 return count;
1189 }
1190
TF_OperationGetControlInputs(TF_Operation * oper,TF_Operation ** control_inputs,int max_control_inputs)1191 int TF_OperationGetControlInputs(TF_Operation* oper,
1192 TF_Operation** control_inputs,
1193 int max_control_inputs) {
1194 int count = 0;
1195 for (const auto* edge : oper->node.in_edges()) {
1196 if (edge->IsControlEdge() && !edge->src()->IsSource()) {
1197 if (count < max_control_inputs) {
1198 control_inputs[count] = ToOperation(edge->src());
1199 }
1200 ++count;
1201 }
1202 }
1203 return count;
1204 }
1205
TF_OperationNumControlOutputs(TF_Operation * oper)1206 int TF_OperationNumControlOutputs(TF_Operation* oper) {
1207 int count = 0;
1208 for (const auto* edge : oper->node.out_edges()) {
1209 if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1210 ++count;
1211 }
1212 }
1213 return count;
1214 }
1215
TF_OperationGetControlOutputs(TF_Operation * oper,TF_Operation ** control_outputs,int max_control_outputs)1216 int TF_OperationGetControlOutputs(TF_Operation* oper,
1217 TF_Operation** control_outputs,
1218 int max_control_outputs) {
1219 int count = 0;
1220 for (const auto* edge : oper->node.out_edges()) {
1221 if (edge->IsControlEdge() && !edge->dst()->IsSink()) {
1222 if (count < max_control_outputs) {
1223 control_outputs[count] = ToOperation(edge->dst());
1224 }
1225 ++count;
1226 }
1227 }
1228 return count;
1229 }
1230
TF_OperationGetAttrMetadata(TF_Operation * oper,const char * attr_name,TF_Status * status)1231 TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper,
1232 const char* attr_name,
1233 TF_Status* status) {
1234 TF_AttrMetadata metadata;
1235 const auto* attr = GetAttrValue(oper, attr_name, status);
1236 if (!status->status.ok()) return metadata;
1237 switch (attr->value_case()) {
1238 #define SINGLE_CASE(kK, attr_type, size_expr) \
1239 case tensorflow::AttrValue::kK: \
1240 metadata.is_list = 0; \
1241 metadata.list_size = -1; \
1242 metadata.type = attr_type; \
1243 metadata.total_size = size_expr; \
1244 break;
1245
1246 SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
1247 SINGLE_CASE(kI, TF_ATTR_INT, -1);
1248 SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
1249 SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
1250 SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
1251 SINGLE_CASE(kShape, TF_ATTR_SHAPE,
1252 attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
1253 SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
1254 #undef SINGLE_CASE
1255
1256 case tensorflow::AttrValue::kList:
1257 metadata.is_list = 1;
1258 metadata.list_size = 0;
1259 metadata.total_size = -1;
1260 #define LIST_CASE(field, attr_type, ...) \
1261 if (attr->list().field##_size() > 0) { \
1262 metadata.type = attr_type; \
1263 metadata.list_size = attr->list().field##_size(); \
1264 __VA_ARGS__; \
1265 break; \
1266 }
1267
1268 LIST_CASE(
1269 s, TF_ATTR_STRING, metadata.total_size = 0;
1270 for (int i = 0; i < attr->list().s_size();
1271 ++i) { metadata.total_size += attr->list().s(i).size(); });
1272 LIST_CASE(i, TF_ATTR_INT);
1273 LIST_CASE(f, TF_ATTR_FLOAT);
1274 LIST_CASE(b, TF_ATTR_BOOL);
1275 LIST_CASE(type, TF_ATTR_TYPE);
1276 LIST_CASE(
1277 shape, TF_ATTR_SHAPE, metadata.total_size = 0;
1278 for (int i = 0; i < attr->list().shape_size(); ++i) {
1279 const auto& s = attr->list().shape(i);
1280 metadata.total_size += s.unknown_rank() ? 0 : s.dim_size();
1281 });
1282 LIST_CASE(tensor, TF_ATTR_TENSOR);
1283 LIST_CASE(tensor, TF_ATTR_FUNC);
1284 #undef LIST_CASE
1285 // All lists empty, determine the type from the OpDef.
1286 if (metadata.list_size == 0) {
1287 for (int i = 0; i < oper->node.op_def().attr_size(); ++i) {
1288 const auto& a = oper->node.op_def().attr(i);
1289 if (a.name() != attr_name) continue;
1290 const string& typestr = a.type();
1291 if (typestr == "list(string)") {
1292 metadata.type = TF_ATTR_STRING;
1293 } else if (typestr == "list(int)") {
1294 metadata.type = TF_ATTR_INT;
1295 } else if (typestr == "list(float)") {
1296 metadata.type = TF_ATTR_FLOAT;
1297 } else if (typestr == "list(bool)") {
1298 metadata.type = TF_ATTR_BOOL;
1299 } else if (typestr == "list(type)") {
1300 metadata.type = TF_ATTR_TYPE;
1301 } else if (typestr == "list(shape)") {
1302 metadata.type = TF_ATTR_SHAPE;
1303 } else if (typestr == "list(tensor)") {
1304 metadata.type = TF_ATTR_TENSOR;
1305 } else if (typestr == "list(func)") {
1306 metadata.type = TF_ATTR_FUNC;
1307 } else {
1308 status->status = InvalidArgument(
1309 "Attribute '", attr_name,
1310 "' has an empty value of an unrecognized type '", typestr, "'");
1311 return metadata;
1312 }
1313 }
1314 }
1315 break;
1316
1317 case tensorflow::AttrValue::kPlaceholder:
1318 metadata.is_list = 0;
1319 metadata.list_size = -1;
1320 metadata.type = TF_ATTR_PLACEHOLDER;
1321 metadata.total_size = -1;
1322 break;
1323
1324 case tensorflow::AttrValue::kFunc:
1325 metadata.is_list = 0;
1326 metadata.list_size = -1;
1327 metadata.type = TF_ATTR_FUNC;
1328 metadata.total_size = -1;
1329 break;
1330
1331 case tensorflow::AttrValue::VALUE_NOT_SET:
1332 status->status =
1333 InvalidArgument("Attribute '", attr_name, "' has no value set");
1334 break;
1335 }
1336 return metadata;
1337 }
1338
TF_OperationGetAttrString(TF_Operation * oper,const char * attr_name,void * value,size_t max_length,TF_Status * status)1339 void TF_OperationGetAttrString(TF_Operation* oper, const char* attr_name,
1340 void* value, size_t max_length,
1341 TF_Status* status) {
1342 const auto* attr = GetAttrValue(oper, attr_name, status);
1343 if (!status->status.ok()) return;
1344 if (attr->value_case() != tensorflow::AttrValue::kS) {
1345 status->status =
1346 InvalidArgument("Attribute '", attr_name, "' is not a string");
1347 return;
1348 }
1349 if (max_length <= 0) {
1350 return;
1351 }
1352 const auto& s = attr->s();
1353 std::memcpy(value, s.data(), std::min<size_t>(s.length(), max_length));
1354 }
1355
TF_OperationGetAttrStringList(TF_Operation * oper,const char * attr_name,void ** values,size_t * lengths,int max_values,void * storage,size_t storage_size,TF_Status * status)1356 void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name,
1357 void** values, size_t* lengths,
1358 int max_values, void* storage,
1359 size_t storage_size, TF_Status* status) {
1360 const auto* attr = GetAttrValue(oper, attr_name, status);
1361 if (!status->status.ok()) return;
1362 if (attr->value_case() != tensorflow::AttrValue::kList) {
1363 status->status =
1364 InvalidArgument("Value for '", attr_name, "' is not a list");
1365 return;
1366 }
1367 const auto len = std::min(max_values, attr->list().s_size());
1368 char* p = static_cast<char*>(storage);
1369 for (int i = 0; i < len; ++i) {
1370 const string& s = attr->list().s(i);
1371 values[i] = p;
1372 lengths[i] = s.size();
1373 if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
1374 status->status = InvalidArgument(
1375 "Not enough storage to hold the requested list of strings");
1376 return;
1377 }
1378 memcpy(values[i], s.data(), s.size());
1379 p += s.size();
1380 }
1381 }
1382
1383 #define DEFINE_GETATTR(func, c_type, cpp_type, list_field) \
1384 void func(TF_Operation* oper, const char* attr_name, c_type* value, \
1385 TF_Status* status) { \
1386 cpp_type v; \
1387 status->status = \
1388 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &v); \
1389 if (!status->status.ok()) return; \
1390 *value = static_cast<c_type>(v); \
1391 } \
1392 void func##List(TF_Operation* oper, const char* attr_name, c_type* values, \
1393 int max_values, TF_Status* status) { \
1394 const auto* attr = GetAttrValue(oper, attr_name, status); \
1395 if (!status->status.ok()) return; \
1396 if (attr->value_case() != tensorflow::AttrValue::kList) { \
1397 status->status = \
1398 InvalidArgument("Value for '", attr_name, "' is not a list."); \
1399 return; \
1400 } \
1401 const auto len = std::min(max_values, attr->list().list_field##_size()); \
1402 for (int i = 0; i < len; ++i) { \
1403 values[i] = static_cast<c_type>(attr->list().list_field(i)); \
1404 } \
1405 }
1406 DEFINE_GETATTR(TF_OperationGetAttrInt, int64_t, tensorflow::int64, i);
1407 DEFINE_GETATTR(TF_OperationGetAttrFloat, float, float, f);
1408 DEFINE_GETATTR(TF_OperationGetAttrBool, unsigned char, bool, b);
1409 DEFINE_GETATTR(TF_OperationGetAttrType, TF_DataType, DataType, type);
1410 #undef DEFINE_GETATTR
1411
TF_OperationGetAttrShape(TF_Operation * oper,const char * attr_name,int64_t * value,int num_dims,TF_Status * status)1412 void TF_OperationGetAttrShape(TF_Operation* oper, const char* attr_name,
1413 int64_t* value, int num_dims, TF_Status* status) {
1414 PartialTensorShape shape;
1415 status->status =
1416 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shape);
1417 if (!status->status.ok()) return;
1418 auto len = std::min(shape.dims(), num_dims);
1419 for (int i = 0; i < len; ++i) {
1420 value[i] = shape.dim_size(i);
1421 }
1422 }
1423
TF_OperationGetAttrShapeList(TF_Operation * oper,const char * attr_name,int64_t ** dims,int * num_dims,int num_shapes,int64_t * storage,int storage_size,TF_Status * status)1424 void TF_OperationGetAttrShapeList(TF_Operation* oper, const char* attr_name,
1425 int64_t** dims, int* num_dims, int num_shapes,
1426 int64_t* storage, int storage_size,
1427 TF_Status* status) {
1428 std::vector<PartialTensorShape> shapes;
1429 status->status =
1430 tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &shapes);
1431 if (!status->status.ok()) return;
1432 auto len = std::min(static_cast<int>(shapes.size()), num_shapes);
1433 int64_t* p = storage;
1434 int storage_left = storage_size;
1435 for (int i = 0; i < len; ++i) {
1436 // shapes[i].dims() == -1 for shapes with an unknown rank.
1437 int64_t n = shapes[i].dims();
1438 num_dims[i] = n;
1439 dims[i] = p;
1440 if (n < 0) {
1441 continue;
1442 }
1443 if (storage_left < n) {
1444 status->status = InvalidArgument(
1445 "Not enough storage to hold the requested list of shapes");
1446 return;
1447 }
1448 storage_left -= n;
1449 for (int j = 0; j < n; ++j, ++p) {
1450 *p = shapes[i].dim_size(j);
1451 }
1452 }
1453 }
1454
TF_OperationGetAttrTensorShapeProto(TF_Operation * oper,const char * attr_name,TF_Buffer * value,TF_Status * status)1455 void TF_OperationGetAttrTensorShapeProto(TF_Operation* oper,
1456 const char* attr_name,
1457 TF_Buffer* value, TF_Status* status) {
1458 const auto* attr = GetAttrValue(oper, attr_name, status);
1459 if (!status->status.ok()) return;
1460 if (attr->value_case() != tensorflow::AttrValue::kShape) {
1461 status->status =
1462 InvalidArgument("Value for '", attr_name, "' is not a shape.");
1463 return;
1464 }
1465 status->status = MessageToBuffer(attr->shape(), value);
1466 }
1467
TF_OperationGetAttrTensorShapeProtoList(TF_Operation * oper,const char * attr_name,TF_Buffer ** values,int max_values,TF_Status * status)1468 void TF_OperationGetAttrTensorShapeProtoList(TF_Operation* oper,
1469 const char* attr_name,
1470 TF_Buffer** values, int max_values,
1471 TF_Status* status) {
1472 const auto* attr = GetAttrValue(oper, attr_name, status);
1473 if (!status->status.ok()) return;
1474 if (attr->value_case() != tensorflow::AttrValue::kList) {
1475 status->status =
1476 InvalidArgument("Value for '", attr_name, "' is not a list");
1477 return;
1478 }
1479 const auto len = std::min(max_values, attr->list().shape_size());
1480 for (int i = 0; i < len; ++i) {
1481 values[i] = TF_NewBuffer();
1482 status->status = MessageToBuffer(attr->list().shape(i), values[i]);
1483 if (!status->status.ok()) {
1484 // Delete everything allocated to far, the operation has failed.
1485 for (int j = 0; j <= i; ++j) {
1486 TF_DeleteBuffer(values[j]);
1487 }
1488 return;
1489 }
1490 }
1491 }
1492
TF_OperationGetAttrTensor(TF_Operation * oper,const char * attr_name,TF_Tensor ** value,TF_Status * status)1493 void TF_OperationGetAttrTensor(TF_Operation* oper, const char* attr_name,
1494 TF_Tensor** value, TF_Status* status) {
1495 *value = nullptr;
1496 Tensor t;
1497 status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &t);
1498 if (!status->status.ok()) return;
1499 *value = TF_TensorFromTensor(t, &status->status);
1500 }
1501
TF_OperationGetAttrTensorList(TF_Operation * oper,const char * attr_name,TF_Tensor ** values,int max_values,TF_Status * status)1502 void TF_OperationGetAttrTensorList(TF_Operation* oper, const char* attr_name,
1503 TF_Tensor** values, int max_values,
1504 TF_Status* status) {
1505 std::vector<Tensor> ts;
1506 status->status = tensorflow::GetNodeAttr(oper->node.attrs(), attr_name, &ts);
1507 if (!status->status.ok()) return;
1508 const auto len = std::min(max_values, static_cast<int>(ts.size()));
1509 for (int i = 0; i < len; ++i) {
1510 values[i] = TF_TensorFromTensor(ts[i], &status->status);
1511 }
1512 }
1513
TF_OperationGetAttrValueProto(TF_Operation * oper,const char * attr_name,TF_Buffer * output_attr_value,TF_Status * status)1514 void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
1515 TF_Buffer* output_attr_value,
1516 TF_Status* status) {
1517 const auto* attr = GetAttrValue(oper, attr_name, status);
1518 if (!status->status.ok()) return;
1519 status->status = MessageToBuffer(*attr, output_attr_value);
1520 }
1521
TF_OperationToNodeDef(TF_Operation * oper,TF_Buffer * output_node_def,TF_Status * status)1522 void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
1523 TF_Status* status) {
1524 status->status = MessageToBuffer(oper->node.def(), output_node_def);
1525 }
1526
1527 // TF_Graph functions ---------------------------------------------------------
1528
TF_Graph()1529 TF_Graph::TF_Graph()
1530 : graph(tensorflow::OpRegistry::Global()),
1531 refiner(graph.versions().producer(), graph.op_registry()),
1532 delete_requested(false),
1533 parent(nullptr),
1534 parent_inputs(nullptr) {
1535 // Tell the shape refiner to also run shape inference on functions.
1536 refiner.set_function_library_for_shape_inference(&graph.flib_def());
1537 }
1538
TF_NewGraph()1539 TF_Graph* TF_NewGraph() { return new TF_Graph; }
1540
TF_DeleteGraph(TF_Graph * g)1541 void TF_DeleteGraph(TF_Graph* g) {
1542 if (g == nullptr) return;
1543 g->mu.lock();
1544 g->delete_requested = true;
1545 const bool del = g->sessions.empty();
1546 g->mu.unlock();
1547 if (del) delete g;
1548 }
1549
TF_GraphOperationByName(TF_Graph * graph,const char * oper_name)1550 TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
1551 mutex_lock l(graph->mu);
1552 auto iter = graph->name_map.find(oper_name);
1553 if (iter == graph->name_map.end()) {
1554 return nullptr;
1555 } else {
1556 return ToOperation(iter->second);
1557 }
1558 }
1559
TF_GraphNextOperation(TF_Graph * graph,size_t * pos)1560 TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
1561 if (*pos == 0) {
1562 // Advance past the first sentinel nodes in every graph (the source & sink).
1563 *pos += 2;
1564 } else {
1565 // Advance to the next node.
1566 *pos += 1;
1567 }
1568
1569 mutex_lock l(graph->mu);
1570 while (*pos < static_cast<size_t>(graph->graph.num_node_ids())) {
1571 Node* node = graph->graph.FindNodeId(*pos);
1572 // FindNodeId() returns nullptr for nodes that have been deleted.
1573 // We aren't currently allowing nodes to be deleted, but it is safer
1574 // to still check.
1575 if (node != nullptr) return ToOperation(node);
1576 *pos += 1;
1577 }
1578
1579 // No more nodes.
1580 return nullptr;
1581 }
1582
TF_GraphToGraphDef(TF_Graph * graph,TF_Buffer * output_graph_def,TF_Status * status)1583 void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
1584 TF_Status* status) {
1585 GraphDef def;
1586 {
1587 mutex_lock l(graph->mu);
1588 graph->graph.ToGraphDef(&def);
1589 }
1590 status->status = MessageToBuffer(def, output_graph_def);
1591 }
1592
TF_GraphGetOpDef(TF_Graph * graph,const char * op_name,TF_Buffer * output_op_def,TF_Status * status)1593 void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name,
1594 TF_Buffer* output_op_def, TF_Status* status) {
1595 const OpDef* op_def;
1596 {
1597 mutex_lock l(graph->mu);
1598 status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def);
1599 if (!status->status.ok()) return;
1600 }
1601 status->status = MessageToBuffer(*op_def, output_op_def);
1602 }
1603
TF_GraphVersions(TF_Graph * graph,TF_Buffer * output_version_def,TF_Status * status)1604 void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def,
1605 TF_Status* status) {
1606 VersionDef versions;
1607 {
1608 mutex_lock l(graph->mu);
1609 versions = graph->graph.versions();
1610 }
1611 status->status = MessageToBuffer(versions, output_version_def);
1612 }
1613
TF_NewImportGraphDefOptions()1614 TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
1615 return new TF_ImportGraphDefOptions;
1616 }
TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions * opts)1617 void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions* opts) {
1618 delete opts;
1619 }
TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions * opts,const char * prefix)1620 void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
1621 const char* prefix) {
1622 opts->opts.prefix = prefix;
1623 }
TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions * opts,const char * device)1624 void TF_ImportGraphDefOptionsSetDefaultDevice(TF_ImportGraphDefOptions* opts,
1625 const char* device) {
1626 opts->opts.default_device = device;
1627 }
1628
TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions * opts,unsigned char uniquify_names)1629 void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
1630 unsigned char uniquify_names) {
1631 opts->opts.uniquify_names = uniquify_names;
1632 }
1633
TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions * opts,unsigned char uniquify_prefix)1634 void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
1635 unsigned char uniquify_prefix) {
1636 opts->opts.uniquify_prefix = uniquify_prefix;
1637 }
1638
TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions * opts,const char * src_name,int src_index,TF_Output dst)1639 void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
1640 const char* src_name,
1641 int src_index, TF_Output dst) {
1642 opts->tensor_id_data.push_back(src_name);
1643 const string& src_name_str = opts->tensor_id_data.back();
1644 // We don't need to store dst's name in tensor_id_data, since `dst` must
1645 // outlive the ImportGraphDef call.
1646 opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst);
1647 }
1648
TF_ImportGraphDefOptionsRemapControlDependency(TF_ImportGraphDefOptions * opts,const char * src_name,TF_Operation * dst)1649 void TF_ImportGraphDefOptionsRemapControlDependency(
1650 TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst) {
1651 opts->opts.input_map[TensorId(src_name, tensorflow::Graph::kControlSlot)] =
1652 TensorId(dst->node.name(), tensorflow::Graph::kControlSlot);
1653 }
1654
TF_ImportGraphDefOptionsAddControlDependency(TF_ImportGraphDefOptions * opts,TF_Operation * oper)1655 extern void TF_ImportGraphDefOptionsAddControlDependency(
1656 TF_ImportGraphDefOptions* opts, TF_Operation* oper) {
1657 opts->opts.control_dependencies.push_back(oper->node.name());
1658 }
1659
TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions * opts,const char * oper_name,int index)1660 void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts,
1661 const char* oper_name, int index) {
1662 opts->tensor_id_data.push_back(oper_name);
1663 const string& oper_name_str = opts->tensor_id_data.back();
1664 opts->opts.return_tensors.emplace_back(oper_name_str, index);
1665 }
1666
TF_ImportGraphDefOptionsNumReturnOutputs(const TF_ImportGraphDefOptions * opts)1667 int TF_ImportGraphDefOptionsNumReturnOutputs(
1668 const TF_ImportGraphDefOptions* opts) {
1669 return opts->opts.return_tensors.size();
1670 }
1671
TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions * opts,const char * oper_name)1672 void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts,
1673 const char* oper_name) {
1674 opts->opts.return_nodes.push_back(oper_name);
1675 }
1676
TF_ImportGraphDefOptionsNumReturnOperations(const TF_ImportGraphDefOptions * opts)1677 int TF_ImportGraphDefOptionsNumReturnOperations(
1678 const TF_ImportGraphDefOptions* opts) {
1679 return opts->opts.return_nodes.size();
1680 }
1681
TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults * results,int * num_outputs,TF_Output ** outputs)1682 void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results,
1683 int* num_outputs,
1684 TF_Output** outputs) {
1685 *num_outputs = results->return_tensors.size();
1686 *outputs = results->return_tensors.data();
1687 }
1688
TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults * results,int * num_opers,TF_Operation *** opers)1689 void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results,
1690 int* num_opers,
1691 TF_Operation*** opers) {
1692 *num_opers = results->return_nodes.size();
1693 *opers = results->return_nodes.data();
1694 }
1695
TF_ImportGraphDefResultsMissingUnusedInputMappings(TF_ImportGraphDefResults * results,int * num_missing_unused_input_mappings,const char *** src_names,int ** src_indexes)1696 void TF_ImportGraphDefResultsMissingUnusedInputMappings(
1697 TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings,
1698 const char*** src_names, int** src_indexes) {
1699 *num_missing_unused_input_mappings = results->missing_unused_key_names.size();
1700 *src_names = results->missing_unused_key_names.data();
1701 *src_indexes = results->missing_unused_key_indexes.data();
1702 }
1703
TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults * results)1704 void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) {
1705 delete results;
1706 }
1707
GraphImportGraphDefLocked(TF_Graph * graph,const GraphDef & def,const TF_ImportGraphDefOptions * opts,TF_ImportGraphDefResults * tf_results,TF_Status * status)1708 static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def,
1709 const TF_ImportGraphDefOptions* opts,
1710 TF_ImportGraphDefResults* tf_results,
1711 TF_Status* status)
1712 TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
1713 const int last_node_id = graph->graph.num_node_ids();
1714 tensorflow::ImportGraphDefResults results;
1715 status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph,
1716 &graph->refiner, &results);
1717 if (!status->status.ok()) return;
1718
1719 // Add new nodes to name_map
1720 for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) {
1721 auto* node = graph->graph.FindNodeId(i);
1722 if (node != nullptr) graph->name_map[node->name()] = node;
1723 }
1724
1725 // Populate return_tensors
1726 DCHECK(tf_results->return_tensors.empty());
1727 tf_results->return_tensors.resize(results.return_tensors.size());
1728 for (int i = 0; i < results.return_tensors.size(); ++i) {
1729 tf_results->return_tensors[i].oper =
1730 ToOperation(results.return_tensors[i].first);
1731 tf_results->return_tensors[i].index = results.return_tensors[i].second;
1732 }
1733
1734 // Populate return_nodes
1735 DCHECK(tf_results->return_nodes.empty());
1736 tf_results->return_nodes.resize(results.return_nodes.size());
1737 for (int i = 0; i < results.return_nodes.size(); ++i) {
1738 tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]);
1739 }
1740
1741 // Populate missing unused map keys
1742 DCHECK(tf_results->missing_unused_key_names.empty());
1743 DCHECK(tf_results->missing_unused_key_indexes.empty());
1744 DCHECK(tf_results->missing_unused_key_names_data.empty());
1745
1746 size_t size = results.missing_unused_input_map_keys.size();
1747 tf_results->missing_unused_key_names.resize(size);
1748 tf_results->missing_unused_key_indexes.resize(size);
1749
1750 for (int i = 0; i < size; ++i) {
1751 TensorId id = results.missing_unused_input_map_keys[i];
1752 tf_results->missing_unused_key_names_data.emplace_back(id.first);
1753 tf_results->missing_unused_key_names[i] =
1754 tf_results->missing_unused_key_names_data.back().c_str();
1755 tf_results->missing_unused_key_indexes[i] = id.second;
1756 }
1757 }
1758
TF_GraphImportGraphDefWithResults(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Status * status)1759 TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
1760 TF_Graph* graph, const TF_Buffer* graph_def,
1761 const TF_ImportGraphDefOptions* options, TF_Status* status) {
1762 GraphDef def;
1763 if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
1764 graph_def->length)) {
1765 status->status = InvalidArgument("Invalid GraphDef");
1766 return nullptr;
1767 }
1768 auto results = new TF_ImportGraphDefResults();
1769 mutex_lock l(graph->mu);
1770 GraphImportGraphDefLocked(graph, def, options, results, status);
1771 if (!status->status.ok()) {
1772 delete results;
1773 return nullptr;
1774 }
1775 return results;
1776 }
1777
TF_GraphImportGraphDefWithReturnOutputs(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Output * return_outputs,int num_return_outputs,TF_Status * status)1778 void TF_GraphImportGraphDefWithReturnOutputs(
1779 TF_Graph* graph, const TF_Buffer* graph_def,
1780 const TF_ImportGraphDefOptions* options, TF_Output* return_outputs,
1781 int num_return_outputs, TF_Status* status) {
1782 if (num_return_outputs != options->opts.return_tensors.size()) {
1783 status->status = InvalidArgument("Expected 'num_return_outputs' to be ",
1784 options->opts.return_tensors.size(),
1785 ", got ", num_return_outputs);
1786 return;
1787 }
1788 if (num_return_outputs > 0 && return_outputs == nullptr) {
1789 status->status = InvalidArgument(
1790 "'return_outputs' must be preallocated to length ", num_return_outputs);
1791 return;
1792 }
1793 GraphDef def;
1794 if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
1795 graph_def->length)) {
1796 status->status = InvalidArgument("Invalid GraphDef");
1797 return;
1798 }
1799 TF_ImportGraphDefResults results;
1800 mutex_lock l(graph->mu);
1801 GraphImportGraphDefLocked(graph, def, options, &results, status);
1802 DCHECK_EQ(results.return_tensors.size(), num_return_outputs);
1803 memcpy(return_outputs, results.return_tensors.data(),
1804 num_return_outputs * sizeof(TF_Output));
1805 }
1806
TF_GraphImportGraphDef(TF_Graph * graph,const TF_Buffer * graph_def,const TF_ImportGraphDefOptions * options,TF_Status * status)1807 void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def,
1808 const TF_ImportGraphDefOptions* options,
1809 TF_Status* status) {
1810 TF_ImportGraphDefResults* results =
1811 TF_GraphImportGraphDefWithResults(graph, graph_def, options, status);
1812 TF_DeleteImportGraphDefResults(results);
1813 }
1814
1815 // While loop functions -------------------------------------------------------
1816
1817 namespace {
1818
1819 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1820
1821 // Creates a placeholder representing an input to the cond or body graph.
1822 // TODO(skyewm): remove these from final graph
CreateInput(const TF_Output & parent_input,TF_Graph * g,const char * name,TF_Output * input,TF_Status * status)1823 bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name,
1824 TF_Output* input, TF_Status* status) {
1825 TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name);
1826 TF_SetAttrType(desc, "dtype", TF_OperationOutputType(parent_input));
1827 // TODO(skyewm): set placeholder shape
1828 TF_Operation* oper = TF_FinishOperation(desc, status);
1829 if (!status->status.ok()) return false;
1830 *input = {oper, 0};
1831 return true;
1832 }
1833
1834 // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input
1835 // `src_inputs[i]` will have that input replaced with `dst_inputs[i]`. `prefix`
1836 // will be prepended to copied node names. `control_deps` are nodes in
1837 // `dst_graph` that the copied `src_graph` nodes will have control dependencies
1838 // on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes
1839 // in `dst_graph` will be returned. `return_nodes` must be non-null.
CopyGraph(Graph * src_graph,Graph * dst_graph,tensorflow::ShapeRefiner * dst_refiner,const TF_Output * src_inputs,const std::vector<tensorflow::Output> & dst_inputs,const string & prefix,const std::vector<tensorflow::Operation> & control_deps,const TF_Output * nodes_to_return,int nreturn_nodes,std::vector<tensorflow::Output> * return_nodes)1840 Status CopyGraph(Graph* src_graph, Graph* dst_graph,
1841 tensorflow::ShapeRefiner* dst_refiner,
1842 const TF_Output* src_inputs,
1843 const std::vector<tensorflow::Output>& dst_inputs,
1844 const string& prefix,
1845 const std::vector<tensorflow::Operation>& control_deps,
1846 const TF_Output* nodes_to_return, int nreturn_nodes,
1847 std::vector<tensorflow::Output>* return_nodes) {
1848 DCHECK(return_nodes != nullptr);
1849 GraphDef gdef;
1850 src_graph->ToGraphDef(&gdef);
1851
1852 tensorflow::ImportGraphDefOptions opts;
1853 opts.prefix = prefix;
1854
1855 for (int i = 0; i < dst_inputs.size(); ++i) {
1856 opts.input_map[ToTensorId(src_inputs[i])] =
1857 TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index());
1858 }
1859 opts.skip_mapped_nodes = true;
1860
1861 for (const tensorflow::Operation& op : control_deps) {
1862 opts.control_dependencies.push_back(op.node()->name());
1863 }
1864
1865 for (int i = 0; i < nreturn_nodes; ++i) {
1866 opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
1867 }
1868
1869 // TODO(skyewm): change to OutputTensor
1870 tensorflow::ImportGraphDefResults results;
1871 TF_RETURN_IF_ERROR(
1872 ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
1873
1874 for (const auto& pair : results.return_tensors) {
1875 return_nodes->emplace_back(pair.first, pair.second);
1876 }
1877 return Status::OK();
1878 }
1879
ValidateConstWhileParams(const TF_WhileParams & params,TF_Status * s)1880 bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) {
1881 if (params.cond_graph == nullptr || params.body_graph == nullptr ||
1882 params.cond_graph->parent == nullptr ||
1883 params.cond_graph->parent != params.body_graph->parent ||
1884 params.cond_graph->parent_inputs != params.body_graph->parent_inputs ||
1885 params.ninputs <= 0 || params.cond_inputs == nullptr ||
1886 params.body_inputs == nullptr || params.body_outputs == nullptr) {
1887 s->status = InvalidArgument(
1888 "TF_WhileParams must be created by successful TF_NewWhile() call");
1889 return false;
1890 }
1891 return true;
1892 }
1893
ValidateInputWhileParams(const TF_WhileParams & params,TF_Status * s)1894 bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) {
1895 if (params.cond_output.oper == nullptr) {
1896 s->status = InvalidArgument("TF_WhileParams `cond_output` field isn't set");
1897 return false;
1898 }
1899 for (int i = 0; i < params.ninputs; ++i) {
1900 if (params.body_outputs[i].oper == nullptr) {
1901 s->status = InvalidArgument("TF_WhileParams `body_outputs[", i, "]` ",
1902 "field isn't set");
1903 return false;
1904 }
1905 }
1906 if (params.name == nullptr) {
1907 s->status = InvalidArgument("TF_WhileParams `name` field is null");
1908 return false;
1909 }
1910 return true;
1911 }
1912
1913 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1914
FreeWhileResources(const TF_WhileParams * params)1915 void FreeWhileResources(const TF_WhileParams* params) {
1916 TF_DeleteGraph(params->cond_graph);
1917 TF_DeleteGraph(params->body_graph);
1918 delete[] params->cond_inputs;
1919 delete[] params->body_inputs;
1920 delete[] params->body_outputs;
1921 }
1922
EmptyWhileParams()1923 TF_WhileParams EmptyWhileParams() {
1924 return {0, nullptr, nullptr, {nullptr, 0},
1925 nullptr, nullptr, nullptr, nullptr};
1926 }
1927
1928 } // namespace
1929
TF_NewWhile(TF_Graph * g,TF_Output * inputs,int ninputs,TF_Status * status)1930 TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs,
1931 TF_Status* status) {
1932 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
1933 status->status = tensorflow::errors::Unimplemented(
1934 "Creating while loops is not supported on mobile. File a bug at "
1935 "https://github.com/tensorflow/tensorflow/issues if this feature is "
1936 "important to you");
1937 return EmptyWhileParams();
1938 #else
1939 if (ninputs == 0) {
1940 status->status =
1941 InvalidArgument("TF_NewWhile() must be passed at least one input");
1942 return EmptyWhileParams();
1943 }
1944
1945 TF_Graph* cond_graph = TF_NewGraph();
1946 TF_Graph* body_graph = TF_NewGraph();
1947 cond_graph->parent = g;
1948 cond_graph->parent_inputs = inputs;
1949 body_graph->parent = g;
1950 body_graph->parent_inputs = inputs;
1951
1952 TF_Output* cond_inputs = new TF_Output[ninputs];
1953 TF_Output cond_output = {nullptr, -1};
1954 TF_Output* body_inputs = new TF_Output[ninputs];
1955 TF_Output* body_outputs = new TF_Output[ninputs];
1956 for (int i = 0; i < ninputs; ++i) body_outputs[i] = {nullptr, -1};
1957 const char* name = nullptr;
1958
1959 for (int i = 0; i < ninputs; ++i) {
1960 // TODO(skyewm): prefix names with underscore (requires some plumbing)
1961 if (!CreateInput(inputs[i], cond_graph, StrCat("cond_input", i).c_str(),
1962 &cond_inputs[i], status)) {
1963 break;
1964 }
1965 if (!CreateInput(inputs[i], body_graph, StrCat("body_input", i).c_str(),
1966 &body_inputs[i], status)) {
1967 break;
1968 }
1969 }
1970
1971 TF_WhileParams params = {ninputs, cond_graph, cond_inputs, cond_output,
1972 body_graph, body_inputs, body_outputs, name};
1973
1974 if (!status->status.ok()) {
1975 FreeWhileResources(¶ms);
1976 return EmptyWhileParams();
1977 }
1978 return params;
1979 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
1980 }
1981
1982 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
1983 namespace {
1984
1985 // TODO(skyewm): make nodes in while loop unfetchable like in Python version
TF_FinishWhileHelper(const TF_WhileParams * params,TF_Status * status,TF_Output * outputs)1986 void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status,
1987 TF_Output* outputs) {
1988 if (!ValidateInputWhileParams(*params, status)) return;
1989
1990 TF_Graph* parent = params->cond_graph->parent;
1991 TF_Output* parent_inputs = params->cond_graph->parent_inputs;
1992 int num_loop_vars = params->ninputs;
1993
1994 mutex_lock l(parent->mu);
1995
1996 // 'cond_fn' copies the cond graph into the parent graph.
1997 tensorflow::ops::CondGraphBuilderFn cond_fn =
1998 [params, parent](const tensorflow::Scope& scope,
1999 const std::vector<tensorflow::Output>& inputs,
2000 tensorflow::Output* output) {
2001 DCHECK_EQ(scope.graph(), &parent->graph);
2002 std::vector<tensorflow::Output> cond_output;
2003 TF_RETURN_IF_ERROR(CopyGraph(
2004 ¶ms->cond_graph->graph, &parent->graph, &parent->refiner,
2005 params->cond_inputs, inputs, scope.impl()->name(),
2006 scope.impl()->control_deps(), ¶ms->cond_output,
2007 /* nreturn_nodes */ 1, &cond_output));
2008 *output = cond_output[0];
2009 return Status::OK();
2010 };
2011
2012 // 'body_fn' copies the body graph into the parent graph.
2013 tensorflow::ops::BodyGraphBuilderFn body_fn =
2014 [params, parent, num_loop_vars](
2015 const tensorflow::Scope& scope,
2016 const std::vector<tensorflow::Output>& inputs,
2017 std::vector<tensorflow::Output>* outputs) {
2018 DCHECK_EQ(scope.graph(), &parent->graph);
2019 TF_RETURN_IF_ERROR(
2020 CopyGraph(¶ms->body_graph->graph, &parent->graph,
2021 &parent->refiner, params->body_inputs, inputs,
2022 scope.impl()->name(), scope.impl()->control_deps(),
2023 params->body_outputs, num_loop_vars, outputs));
2024 return Status::OK();
2025 };
2026
2027 // Create the while loop using an internal scope.
2028 tensorflow::Scope scope =
2029 NewInternalScope(&parent->graph, &status->status, &parent->refiner)
2030 .NewSubScope(params->name);
2031
2032 const int first_new_node_id = parent->graph.num_node_ids();
2033
2034 tensorflow::OutputList loop_outputs;
2035 status->status = tensorflow::ops::BuildWhileLoop(
2036 scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn,
2037 body_fn, params->name, &loop_outputs);
2038
2039 // Update name_map with newly-created ops.
2040 // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns
2041 // a bad status. Once we fix this, we may want to return early instead of
2042 // executing the following code.
2043 for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) {
2044 Node* new_node = parent->graph.FindNodeId(i);
2045 if (new_node == nullptr) continue;
2046 parent->name_map[new_node->name()] = new_node;
2047 }
2048
2049 // Populate 'outputs'.
2050 DCHECK_LE(loop_outputs.size(), num_loop_vars);
2051 for (int i = 0; i < loop_outputs.size(); ++i) {
2052 outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()};
2053 }
2054 }
2055
2056 } // namespace
2057 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2058
TF_FinishWhile(const TF_WhileParams * params,TF_Status * status,TF_Output * outputs)2059 void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
2060 TF_Output* outputs) {
2061 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2062 status->status = tensorflow::errors::Unimplemented(
2063 "Creating while loops is not supported on mobile. File a bug at "
2064 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2065 "important to you");
2066 #else
2067 // If it appears the caller created or modified `params`, don't free resources
2068 if (!ValidateConstWhileParams(*params, status)) return;
2069 TF_FinishWhileHelper(params, status, outputs);
2070 FreeWhileResources(params);
2071 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2072 }
2073
TF_AbortWhile(const TF_WhileParams * params)2074 void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
2075
TF_AddGradients(TF_Graph * g,TF_Output * y,int ny,TF_Output * x,int nx,TF_Output * dx,TF_Status * status,TF_Output * dy)2076 void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
2077 TF_Output* dx, TF_Status* status, TF_Output* dy) {
2078 TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy);
2079 }
2080
TF_AddGradientsWithPrefix(TF_Graph * g,const char * prefix,TF_Output * y,int ny,TF_Output * x,int nx,TF_Output * dx,TF_Status * status,TF_Output * dy)2081 void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
2082 int ny, TF_Output* x, int nx, TF_Output* dx,
2083 TF_Status* status, TF_Output* dy) {
2084 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2085 status->status = tensorflow::errors::Unimplemented(
2086 "Adding gradients is not supported on mobile. File a bug at "
2087 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2088 "important to you");
2089 #else
2090 std::vector<tensorflow::Output> y_arg = OutputsFromTFOutputs(y, ny);
2091 std::vector<tensorflow::Output> x_arg = OutputsFromTFOutputs(x, nx);
2092 std::vector<tensorflow::Output> dy_arg;
2093
2094 {
2095 // We need to hold on to the lock while we have a scope that uses TF_Graph.
2096 mutex_lock graph_lock(g->mu);
2097
2098 const int first_new_node_id = g->graph.num_node_ids();
2099
2100 string prefix_cmp;
2101 const char* child_scope_name;
2102 if (prefix == nullptr) {
2103 child_scope_name = "gradients";
2104 } else {
2105 prefix_cmp = string(prefix) + "/";
2106 // The operation should fail if the provided name prefix has already been
2107 // used in this graph
2108 for (const auto& pair : g->name_map) {
2109 const string& name = pair.first;
2110 if ((name == prefix) || absl::StartsWith(name, prefix_cmp)) {
2111 status->status = InvalidArgument(
2112 "prefix [", prefix,
2113 "] conflicts with existing node in the graph named [", name, "]");
2114 return;
2115 }
2116 }
2117 child_scope_name = prefix;
2118 }
2119 tensorflow::Scope scope =
2120 NewInternalScope(&g->graph, &status->status, &g->refiner)
2121 .NewSubScope(child_scope_name);
2122
2123 if (dx != nullptr) {
2124 std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
2125 status->status =
2126 AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg);
2127 } else {
2128 status->status = AddSymbolicGradients(scope, y_arg, x_arg, &dy_arg);
2129 }
2130
2131 // Update g->name_map with the name_map from the scope, which will contain
2132 // the new gradient ops.
2133 for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
2134 Node* n = g->graph.FindNodeId(i);
2135 if (n == nullptr) continue;
2136
2137 // Adding the gradients to the graph can alter the prefix to prevent
2138 // name collisions only if this prefix has not been provided explicitly
2139 // by the user. If it was provided, assert that it remained intact.
2140 if (prefix != nullptr && !absl::StartsWith(n->name(), prefix_cmp)) {
2141 status->status = tensorflow::errors::Internal(
2142 "BUG: The gradients prefix have been unexpectedly altered when "
2143 "adding the nodes to the graph. This is a bug. Please file an "
2144 "issue at https://github.com/tensorflow/tensorflow/issues.");
2145 return;
2146 }
2147 // We have a convoluted scheme here: Using the C++ graph construction API
2148 // to add potentially many nodes to the graph without running the checks
2149 // (such as uniqueness of the names of nodes) we run with other functions
2150 // that add a node to the graph (like TF_FinishOperation).
2151 if (!g->name_map.insert(std::make_pair(n->name(), n)).second) {
2152 status->status = tensorflow::errors::Internal(
2153 "BUG: The API allowed construction of a graph with duplicate node "
2154 "names (",
2155 n->name(),
2156 "). This is a bug. Please file an issue at "
2157 "https://github.com/tensorflow/tensorflow/issues.");
2158 }
2159 }
2160 }
2161
2162 // Unpack the results from grad_outputs_arg.
2163 TFOutputsFromOutputs(dy_arg, dy);
2164 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2165 }
2166
2167 // TF_Session functions ----------------------------------------------
2168
TF_Session(tensorflow::Session * s,TF_Graph * g)2169 TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
2170 : session(s), graph(g), last_num_graph_nodes(0), extend_before_run(true) {}
2171
TF_NewSession(TF_Graph * graph,const TF_SessionOptions * opt,TF_Status * status)2172 TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
2173 TF_Status* status) {
2174 Session* session;
2175 status->status = NewSession(opt->options, &session);
2176 if (status->status.ok()) {
2177 TF_Session* new_session = new TF_Session(session, graph);
2178 if (graph != nullptr) {
2179 mutex_lock l(graph->mu);
2180 graph->sessions[new_session] = "";
2181 }
2182 return new_session;
2183 } else {
2184 LOG(ERROR) << status->status;
2185 DCHECK_EQ(nullptr, session);
2186 return nullptr;
2187 }
2188 }
2189
TF_LoadSessionFromSavedModel(const TF_SessionOptions * session_options,const TF_Buffer * run_options,const char * export_dir,const char * const * tags,int tags_len,TF_Graph * graph,TF_Buffer * meta_graph_def,TF_Status * status)2190 TF_Session* TF_LoadSessionFromSavedModel(
2191 const TF_SessionOptions* session_options, const TF_Buffer* run_options,
2192 const char* export_dir, const char* const* tags, int tags_len,
2193 TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status) {
2194 // TODO(sjr): Remove the IS_MOBILE_PLATFORM guard. This will require ensuring
2195 // that the tensorflow/cc/saved_model:loader build target is mobile friendly.
2196 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2197 status->status = tensorflow::errors::Unimplemented(
2198 "Loading a SavedModel is not supported on mobile. File a bug at "
2199 "https://github.com/tensorflow/tensorflow/issues if this feature is "
2200 "important to you");
2201 return nullptr;
2202 #else
2203 mutex_lock l(graph->mu);
2204 if (!graph->name_map.empty()) {
2205 status->status = InvalidArgument("Graph is non-empty.");
2206 return nullptr;
2207 }
2208
2209 RunOptions run_options_proto;
2210 if (run_options != nullptr && !run_options_proto.ParseFromArray(
2211 run_options->data, run_options->length)) {
2212 status->status = InvalidArgument("Unparseable RunOptions proto");
2213 return nullptr;
2214 }
2215
2216 std::unordered_set<string> tag_set;
2217 for (int i = 0; i < tags_len; i++) {
2218 tag_set.insert(string(tags[i]));
2219 }
2220
2221 tensorflow::SavedModelBundle bundle;
2222 status->status =
2223 tensorflow::LoadSavedModel(session_options->options, run_options_proto,
2224 export_dir, tag_set, &bundle);
2225 if (!status->status.ok()) return nullptr;
2226
2227 // Create a TF_Graph from the MetaGraphDef. This is safe as long as Session
2228 // extends using GraphDefs. The Graph instance is different, but equivalent
2229 // to the one used to create the session.
2230 //
2231 // TODO(jhseu): When Session is modified to take Graphs instead of
2232 // GraphDefs, return the Graph generated in LoadSavedModel().
2233 TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions();
2234 TF_ImportGraphDefResults results;
2235 GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(),
2236 import_opts, &results, status);
2237 TF_DeleteImportGraphDefOptions(import_opts);
2238 if (!status->status.ok()) return nullptr;
2239
2240 if (meta_graph_def != nullptr) {
2241 status->status = MessageToBuffer(bundle.meta_graph_def, meta_graph_def);
2242 if (!status->status.ok()) return nullptr;
2243 }
2244
2245 TF_Session* session = new TF_Session(bundle.session.release(), graph);
2246
2247 graph->sessions[session] = "";
2248 session->last_num_graph_nodes = graph->graph.num_node_ids();
2249 return session;
2250 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2251 }
2252
TF_CloseSession(TF_Session * s,TF_Status * status)2253 void TF_CloseSession(TF_Session* s, TF_Status* status) {
2254 status->status = s->session->Close();
2255 }
2256
TF_DeleteSession(TF_Session * s,TF_Status * status)2257 void TF_DeleteSession(TF_Session* s, TF_Status* status) {
2258 status->status = Status::OK();
2259 if (s == nullptr) return;
2260 TF_Graph* const graph = s->graph;
2261 if (graph != nullptr) {
2262 graph->mu.lock();
2263 graph->sessions.erase(s);
2264 const bool del = graph->delete_requested && graph->sessions.empty();
2265 graph->mu.unlock();
2266 if (del) delete graph;
2267 }
2268 delete s->session;
2269 delete s;
2270 }
2271
TF_SessionRun(TF_Session * session,const TF_Buffer * run_options,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Buffer * run_metadata,TF_Status * status)2272 void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options,
2273 const TF_Output* inputs, TF_Tensor* const* input_values,
2274 int ninputs, const TF_Output* outputs,
2275 TF_Tensor** output_values, int noutputs,
2276 const TF_Operation* const* target_opers, int ntargets,
2277 TF_Buffer* run_metadata, TF_Status* status) {
2278 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2279 // directly, instead of requiring us to serialize to a GraphDef and
2280 // call Session::Extend().
2281 if (session->extend_before_run &&
2282 !ExtendSessionGraphHelper(session, status)) {
2283 return;
2284 }
2285
2286 TF_Run_Setup(noutputs, output_values, status);
2287
2288 // Convert from TF_Output and TF_Tensor to a string and Tensor.
2289 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2290 if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2291 for (int i = 0; i < ninputs; ++i) {
2292 input_pairs[i].first = OutputName(inputs[i]);
2293 }
2294
2295 // Convert from TF_Output to string names.
2296 std::vector<string> output_names(noutputs);
2297 for (int i = 0; i < noutputs; ++i) {
2298 output_names[i] = OutputName(outputs[i]);
2299 }
2300
2301 // Convert from TF_Operation* to string names.
2302 std::vector<string> target_names(ntargets);
2303 for (int i = 0; i < ntargets; ++i) {
2304 target_names[i] = target_opers[i]->node.name();
2305 }
2306
2307 // Actually run.
2308 TF_Run_Helper(session->session, nullptr, run_options, input_pairs,
2309 output_names, output_values, target_names, run_metadata,
2310 status);
2311 }
2312
TF_SessionPRunSetup(TF_Session * session,const TF_Output * inputs,int ninputs,const TF_Output * outputs,int noutputs,const TF_Operation * const * target_opers,int ntargets,const char ** handle,TF_Status * status)2313 void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs,
2314 int ninputs, const TF_Output* outputs, int noutputs,
2315 const TF_Operation* const* target_opers, int ntargets,
2316 const char** handle, TF_Status* status) {
2317 *handle = nullptr;
2318
2319 if (session->extend_before_run &&
2320 !ExtendSessionGraphHelper(session, status)) {
2321 return;
2322 }
2323
2324 std::vector<string> input_names(ninputs);
2325 for (int i = 0; i < ninputs; ++i) {
2326 input_names[i] = OutputName(inputs[i]);
2327 }
2328
2329 std::vector<string> output_names(noutputs);
2330 for (int i = 0; i < noutputs; ++i) {
2331 output_names[i] = OutputName(outputs[i]);
2332 }
2333
2334 std::vector<string> target_names(ntargets);
2335 for (int i = 0; i < ntargets; ++i) {
2336 target_names[i] = target_opers[i]->node.name();
2337 }
2338
2339 string new_handle;
2340 status->status = session->session->PRunSetup(input_names, output_names,
2341 target_names, &new_handle);
2342 if (status->status.ok()) {
2343 char* buf = new char[new_handle.size() + 1];
2344 memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
2345 *handle = buf;
2346 }
2347 }
2348
// Releases a handle string returned by TF_SessionPRunSetup, which allocates it
// with new char[]. This frees only the string itself.
void TF_DeletePRunHandle(const char* handle) {
  // TODO(suharshs): Free up any resources held by the partial run state.
  delete[] handle;
}
2353
TF_SessionPRun(TF_Session * session,const char * handle,const TF_Output * inputs,TF_Tensor * const * input_values,int ninputs,const TF_Output * outputs,TF_Tensor ** output_values,int noutputs,const TF_Operation * const * target_opers,int ntargets,TF_Status * status)2354 void TF_SessionPRun(TF_Session* session, const char* handle,
2355 const TF_Output* inputs, TF_Tensor* const* input_values,
2356 int ninputs, const TF_Output* outputs,
2357 TF_Tensor** output_values, int noutputs,
2358 const TF_Operation* const* target_opers, int ntargets,
2359 TF_Status* status) {
2360 // TODO(josh11b,mrry): Change Session to be able to use a Graph*
2361 // directly, instead of requiring us to serialize to a GraphDef and
2362 // call Session::Extend().
2363 if (session->extend_before_run &&
2364 !ExtendSessionGraphHelper(session, status)) {
2365 return;
2366 }
2367
2368 TF_Run_Setup(noutputs, output_values, status);
2369
2370 // Convert from TF_Output and TF_Tensor to a string and Tensor.
2371 std::vector<std::pair<string, Tensor>> input_pairs(ninputs);
2372 if (!TF_Run_Inputs(input_values, &input_pairs, status)) return;
2373 for (int i = 0; i < ninputs; ++i) {
2374 input_pairs[i].first = OutputName(inputs[i]);
2375 }
2376
2377 // Convert from TF_Output to string names.
2378 std::vector<string> output_names(noutputs);
2379 for (int i = 0; i < noutputs; ++i) {
2380 output_names[i] = OutputName(outputs[i]);
2381 }
2382
2383 // Convert from TF_Operation* to string names.
2384 std::vector<string> target_names(ntargets);
2385 for (int i = 0; i < ntargets; ++i) {
2386 target_names[i] = target_opers[i]->node.name();
2387 }
2388
2389 TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,
2390 output_values, target_names, nullptr, status);
2391 }
2392
TF_TryEvaluateConstant(TF_Graph * graph,TF_Output output,TF_Tensor ** result,TF_Status * status)2393 unsigned char TF_TryEvaluateConstant(TF_Graph* graph, TF_Output output,
2394 TF_Tensor** result, TF_Status* status) {
2395 *result = nullptr;
2396 mutex_lock l(graph->mu);
2397 OutputTensor tensor(&output.oper->node, output.index);
2398 bool evaluated;
2399 Tensor result_tensor;
2400 status->status = EvaluateConstantTensor(
2401 tensor, graph->refiner, *graph->graph.op_registry(),
2402 graph->graph.versions().producer(), &evaluated, &result_tensor);
2403 if (evaluated) {
2404 DCHECK(status->status.ok());
2405 *result = TF_TensorFromTensor(result_tensor, &status->status);
2406 if (!status->status.ok()) evaluated = false;
2407 }
2408 return evaluated;
2409 }
2410
TF_NewApiDefMap(TF_Buffer * op_list_buffer,TF_Status * status)2411 TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) {
2412 tensorflow::OpList op_list;
2413 if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) {
2414 status->status = InvalidArgument("Unparseable OpList");
2415 return nullptr;
2416 }
2417 status->status = Status::OK();
2418 return new TF_ApiDefMap(op_list);
2419 }
2420
TF_DeleteApiDefMap(TF_ApiDefMap * apimap)2421 void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; }
2422
TF_ApiDefMapPut(TF_ApiDefMap * api_def_map,const char * text,size_t text_len,TF_Status * status)2423 void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text,
2424 size_t text_len, TF_Status* status) {
2425 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2426 status->status = tensorflow::errors::Unimplemented(
2427 "ApiDefMap is not supported on mobile.");
2428 #else
2429 mutex_lock l(api_def_map->lock);
2430 if (api_def_map->update_docs_called) {
2431 status->status = FailedPrecondition(
2432 "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been "
2433 "called.");
2434 return;
2435 }
2436 string api_def_text(text, text_len);
2437 status->status = api_def_map->api_def_map.LoadApiDef(api_def_text);
2438 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2439 }
2440
TF_ApiDefMapGet(TF_ApiDefMap * api_def_map,const char * name,size_t name_len,TF_Status * status)2441 TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name,
2442 size_t name_len, TF_Status* status) {
2443 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2444 status->status = tensorflow::errors::Unimplemented(
2445 "ApiDefMap is not supported on mobile.");
2446 return nullptr;
2447 #else
2448 mutex_lock l(api_def_map->lock);
2449 if (!api_def_map->update_docs_called) {
2450 api_def_map->api_def_map.UpdateDocs();
2451 api_def_map->update_docs_called = true;
2452 }
2453 string name_str(name, name_len);
2454 const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str);
2455 if (api_def == nullptr) {
2456 return nullptr;
2457 }
2458
2459 TF_Buffer* ret = TF_NewBuffer();
2460 status->status = MessageToBuffer(*api_def, ret);
2461 if (!status->status.ok()) {
2462 TF_DeleteBuffer(ret);
2463 return nullptr;
2464 }
2465 return ret;
2466 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2467 }
2468
TF_GetAllRegisteredKernels(TF_Status * status)2469 TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status) {
2470 tensorflow::KernelList kernel_list = tensorflow::GetAllRegisteredKernels();
2471 TF_Buffer* ret = TF_NewBuffer();
2472 status->status = MessageToBuffer(kernel_list, ret);
2473 if (!status->status.ok()) {
2474 TF_DeleteBuffer(ret);
2475 return nullptr;
2476 }
2477 return ret;
2478 }
2479
TF_GetRegisteredKernelsForOp(const char * name,TF_Status * status)2480 TF_Buffer* TF_GetRegisteredKernelsForOp(const char* name, TF_Status* status) {
2481 tensorflow::KernelList kernel_list =
2482 tensorflow::GetRegisteredKernelsForOp(name);
2483 TF_Buffer* ret = TF_NewBuffer();
2484 status->status = MessageToBuffer(kernel_list, ret);
2485 if (!status->status.ok()) {
2486 TF_DeleteBuffer(ret);
2487 return nullptr;
2488 }
2489 return ret;
2490 }
2491
TF_UpdateEdge(TF_Graph * graph,TF_Output new_src,TF_Input dst,TF_Status * status)2492 void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
2493 TF_Status* status) {
2494 using tensorflow::RecordMutation;
2495 mutex_lock l(graph->mu);
2496 tensorflow::shape_inference::InferenceContext* ic =
2497 graph->refiner.GetContext(&new_src.oper->node);
2498
2499 if (ic->num_outputs() <= new_src.index) {
2500 status->status = tensorflow::errors::OutOfRange(
2501 "Cannot update edge. Output index [", new_src.index,
2502 "] is greater than the number of total outputs [", ic->num_outputs(),
2503 "].");
2504 return;
2505 }
2506 tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);
2507
2508 tensorflow::shape_inference::InferenceContext* ic_dst =
2509 graph->refiner.GetContext(&dst.oper->node);
2510 if (ic_dst->num_inputs() <= dst.index) {
2511 status->status = tensorflow::errors::OutOfRange(
2512 "Cannot update edge. Input index [", dst.index,
2513 "] is greater than the number of total inputs [", ic_dst->num_inputs(),
2514 "].");
2515 return;
2516 }
2517 if (!ic_dst->MergeInput(dst.index, shape)) {
2518 status->status = tensorflow::errors::InvalidArgument(
2519 "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
2520 " and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
2521 return;
2522 }
2523 status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
2524 &dst.oper->node, dst.index);
2525
2526 if (TF_GetCode(status) == TF_OK) {
2527 // This modification only updates the destination node for
2528 // the purposes of running this graph in a session. Thus, we don't
2529 // record the source node as being modified.
2530 RecordMutation(graph, *dst.oper, "updating input tensor");
2531 }
2532 }
2533
2534 // TF_Server functions ----------------------------------------------
2535
2536 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)2537 TF_Server::TF_Server(std::unique_ptr<tensorflow::ServerInterface> server)
2538 : target(server->target()), server(std::move(server)) {}
2539 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2540
TF_NewServer(const void * proto,size_t proto_len,TF_Status * status)2541 TF_Server* TF_NewServer(const void* proto, size_t proto_len,
2542 TF_Status* status) {
2543 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2544 status->status = tensorflow::errors::Unimplemented(
2545 "Server functionality is not supported on mobile");
2546 return nullptr;
2547 #else
2548 tensorflow::ServerDef server_def;
2549 if (!server_def.ParseFromArray(proto, static_cast<int>(proto_len))) {
2550 status->status = InvalidArgument(
2551 "Could not parse provided bytes into a ServerDef protocol buffer");
2552 return nullptr;
2553 }
2554
2555 std::unique_ptr<tensorflow::ServerInterface> out_server;
2556 status->status = tensorflow::NewServer(server_def, &out_server);
2557 if (!status->status.ok()) return nullptr;
2558
2559 return new TF_Server(std::move(out_server));
2560 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2561 }
2562
TF_ServerStart(TF_Server * server,TF_Status * status)2563 void TF_ServerStart(TF_Server* server, TF_Status* status) {
2564 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2565 status->status = tensorflow::errors::Unimplemented(
2566 "Server functionality is not supported on mobile");
2567 #else
2568 status->status = server->server->Start();
2569 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2570 }
2571
TF_ServerStop(TF_Server * server,TF_Status * status)2572 void TF_ServerStop(TF_Server* server, TF_Status* status) {
2573 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2574 status->status = tensorflow::errors::Unimplemented(
2575 "Server functionality is not supported on mobile");
2576 #else
2577 status->status = server->server->Stop();
2578 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2579 }
2580
TF_ServerJoin(TF_Server * server,TF_Status * status)2581 void TF_ServerJoin(TF_Server* server, TF_Status* status) {
2582 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2583 status->status = tensorflow::errors::Unimplemented(
2584 "Server functionality is not supported on mobile");
2585 #else
2586 status->status = server->server->Join();
2587 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2588 }
2589
TF_ServerTarget(TF_Server * server)2590 const char* TF_ServerTarget(TF_Server* server) {
2591 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2592 return nullptr;
2593 #else
2594 return server->target.c_str();
2595 #endif
2596 }
2597
TF_DeleteServer(TF_Server * server)2598 void TF_DeleteServer(TF_Server* server) {
2599 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2600 delete server;
2601 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2602 }
2603
TF_RegisterLogListener(void (* listener)(const char *))2604 void TF_RegisterLogListener(void (*listener)(const char*)) {
2605 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2606 tensorflow::logging::RegisterListener(listener);
2607 #endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
2608 }
2609
TF_RegisterFilesystemPlugin(const char * plugin_filename,TF_Status * status)2610 void TF_RegisterFilesystemPlugin(const char* plugin_filename,
2611 TF_Status* status) {
2612 #if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2613 status->status = tensorflow::errors::Unimplemented(
2614 "FileSystem plugin functionality is not supported on mobile");
2615 #else
2616 status->status = tensorflow::RegisterFilesystemPlugin(plugin_filename);
2617 #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD)
2618 }
2619
2620 } // end extern "C"
2621