1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/c_api_internal.h"
17
18 #include <algorithm>
19 #include <unordered_map>
20 #include <unordered_set>
21
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/node_def_util.h"
26 #include "tensorflow/core/framework/tensor.pb.h" // NOLINT
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/graph/graph.h"
29 #include "tensorflow/core/lib/strings/base64.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31
32 using tensorflow::errors::InvalidArgument;
33
34 namespace tensorflow {
35 namespace {
36
// Class that maintains a one-to-one original node name -> new node name
// mapping. We normalize the names used as input and output arguments to match
// regexp "[a-z][a-z0-9_]*" specified in definition of ArgDef.name.
// Once we rename them, we risk creating a name collision with the other
// node names, so if necessary we add a suffix to make
// names unique. If we have an input named "A" and a node in the function
// body named "a", they will be renamed to "a" and "a_0".
class NodeNameMapping {
 public:
  NodeNameMapping() = default;

  // Normalize the input name and make it unique. This is the same as the
  // function for output, except that it also records the name in
  // name_mapping_ so the original node name can later be Lookup()'d.
  string GetInputName(const string& name);

  // Normalize the output name and make it unique. The result is reserved in
  // used_names_ but intentionally NOT recorded in name_mapping_ (output args
  // live only in the signature and do not correspond to body nodes).
  string GetOutputName(const string& name);

  // Make the node name unique (no normalization: node names are already
  // valid). Records the original -> unique mapping.
  string Uniquify(const string& name);

  // Records name as a used name verbatim. If this name is already used,
  // returns an error status. Used for caller-chosen output names, which must
  // be taken as-is rather than renamed.
  Status UseOutputName(const string& name);

  // Look up how a node name was previously normalized/uniquified.
  // Returns empty if name was never seen.
  string Lookup(const string& name) const;

 private:
  // Returns `name` if unused, otherwise "<name>_<i>" for the smallest i >= 0
  // that is unused. Does not reserve the returned name.
  string UniquifyHelper(const string& name) const;
  // Lowercases letters, maps non-alphanumerics to '_', and strips leading
  // non-letters so the result matches "[a-z][a-z0-9_]*" (or "unknown").
  static string Normalize(string name);

  // The normalized/uniquified names already used as
  // input names (in signature), output names (in signature), and node names
  // (in node_def).
  // This is a superset of values in name_mapping_.
  std::unordered_set<string> used_names_;
  // Mapping from original node name from the graph to the normalized
  // and uniquified version of it.
  std::unordered_map<string, string> name_mapping_;
};
79
Normalize(string name)80 string NodeNameMapping::Normalize(string name) {
81 // Convert letters to lowercase and non-alphanumeric characters to '_'.
82 if (name.empty()) return "unknown";
83 const int n = name.size();
84 for (int i = 0; i < n; ++i) {
85 char c = name[i];
86 if (isalnum(c)) {
87 if (isupper(c)) {
88 name[i] = tolower(c);
89 }
90 } else {
91 name[i] = '_';
92 }
93 }
94
95 // Find the first letter and start with it.
96 int i = 0;
97 for (; i < n; ++i) {
98 if (isalpha(name[i])) break;
99 }
100
101 // Return "unknown" if none of the name's chars were letters.
102 return i == n ? "unknown" : name.substr(i);
103 }
104
string NodeNameMapping::UniquifyHelper(const string& name) const {
  // Fast path: the bare name is still free, so use it unchanged.
  if (used_names_.count(name) == 0) return name;
  // Otherwise probe "<name>_0", "<name>_1", ... until an unused candidate
  // is found. Guaranteed to terminate since used_names_ is finite.
  int suffix = 0;
  while (true) {
    const string candidate = strings::StrCat(name, "_", suffix);
    if (used_names_.count(candidate) == 0) return candidate;
    ++suffix;
  }
}
114
string NodeNameMapping::GetInputName(const string& name) {
  // Inputs are normalized/uniquified exactly like outputs, but an input
  // corresponds to a real graph node, so additionally remember the mapping
  // for later Lookup() calls.
  const string input_name = GetOutputName(name);
  name_mapping_[name] = input_name;
  return input_name;
}
120
string NodeNameMapping::GetOutputName(const string& name) {
  // Normalize first, then disambiguate against everything already reserved.
  const string output_name = UniquifyHelper(Normalize(name));
  // Reserve the name, but don't add it to name_mapping_: output args exist
  // only in the signature and are not nodes.
  used_names_.insert(output_name);
  return output_name;
}
128
string NodeNameMapping::Uniquify(const string& name) {
  // Node names are already syntactically valid, so skip normalization and
  // only resolve collisions. Record both the mapping and the reservation.
  const string unique_name = UniquifyHelper(name);
  name_mapping_[name] = unique_name;
  used_names_.insert(unique_name);
  return unique_name;
}
135
UseOutputName(const string & name)136 Status NodeNameMapping::UseOutputName(const string& name) {
137 const auto& iter = used_names_.find(name);
138 if (iter != used_names_.end()) {
139 return InvalidArgument("Cannot have duplicate output names. Name '", name,
140 "' appears more than once in 'output_names' array.");
141 }
142 used_names_.insert(iter, name);
143 return Status::OK();
144 }
145
string NodeNameMapping::Lookup(const string& name) const {
  // Returns the renamed form of `name`, or "" if the name was never
  // registered via GetInputName()/Uniquify().
  const auto it = name_mapping_.find(name);
  return it == name_mapping_.end() ? string() : it->second;
}
151
ValidateNonRefOutput(const Node * node,int idx)152 Status ValidateNonRefOutput(const Node* node, int idx) {
153 const DataType& dt = node->output_type(idx);
154 return IsRefType(dt)
155 ? InvalidArgument("Output ", idx, " of node '", node->name(),
156 "' has a reference type ", DataTypeString(dt))
157 : Status::OK();
158 }
159
// Converts the graph nodes in `body_nodes` into NodeDefs inside `fdef`.
//
// - `fn_name` is used only to construct error messages.
// - `node_names` maps original graph node names to their uniquified names
//   inside the function; every body node and input node must be present.
// - `tensor_renaming` maps flat tensor names ("<node>:<idx>") to function
//   tensor names; every regular input of every body node must be present.
//
// Side effects on `fdef` beyond adding node_defs: marks the signature
// stateful if any body node's op is stateful, and lifts placeholder-valued
// node attributes into the signature's attr list.
Status FillFunctionBody(
    const string& fn_name, const NodeNameMapping& node_names,
    const std::vector<const Node*>& body_nodes,
    const std::unordered_map<string, string>& tensor_renaming,
    FunctionDef* fdef) {
  // Attr names already present in the signature, so we don't add duplicates
  // when lifting placeholder attrs below.
  std::unordered_set<string> func_attr_names;
  for (const auto& func_attr : fdef->signature().attr()) {
    func_attr_names.insert(func_attr.name());
  }

  // Scratch vectors reused across iterations to avoid reallocation.
  std::vector<const Edge*> in_edges;
  std::vector<const Edge*> control_edges;
  for (const Node* node : body_nodes) {
    NodeDef* node_def = fdef->add_node_def();
    // First, copy the node_def as is. We will patch it next.
    *node_def = node->def();
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }
    node_def->set_name(node_names.Lookup(node->name()));

    // Input names must be set based on nested names in tensor_renaming.
    // Clear the flat input names we got from the original node_def
    // from the graph.
    node_def->clear_input();

    // Collect regular and control inputs. Regular inputs are indexed
    // by the index at which they come into the `node`. Control inputs
    // don't follow any order.
    in_edges.clear();
    in_edges.resize(node->num_inputs(), nullptr);
    control_edges.clear();
    for (const Edge* edge : node->in_edges()) {
      if (edge->src()->IsSource()) continue;
      if (edge->IsControlEdge()) {
        control_edges.push_back(edge);
      } else {
        in_edges[edge->dst_input()] = edge;
      }
    }

    // Add regular inputs.
    for (size_t i = 0; i < in_edges.size(); ++i) {
      const Edge* edge = in_edges[i];
      string original_input_name;
      if (edge == nullptr) {
        // A backedge might not appear as a regular Edge, but be only present
        // in the node_def. Such edges are referred to as requested_inputs().
        if (i >= node->requested_inputs().size()) {
          return InvalidArgument(
              "Graph to be converted to function appears to be malformed. ",
              "Node ", node->name(), " is missing input edge ", i);
        }
        original_input_name =
            ParseTensorName(node->requested_inputs()[i]).ToString();
      } else {
        original_input_name =
            strings::StrCat(edge->src()->name(), ":", edge->src_output());
      }

      // Translate the flat graph tensor name to the function-internal
      // nested name computed by GraphToFunctionDef.
      const auto iter = tensor_renaming.find(original_input_name);
      if (iter == tensor_renaming.end()) {
        return InvalidArgument(
            "Input ", i, ", '", original_input_name, "', of node '",
            node->name(), "' in function '", fn_name,
            "' is not available. You might need to include it in inputs "
            "or include its source node in the body");
      }
      node_def->add_input(iter->second);
    }

    // Add control inputs.
    for (const Edge* edge : control_edges) {
      // Add this control input only if the src node is in the body or a part of
      // the inputs.
      const string normalized = node_names.Lookup(edge->src()->name());
      // If we did not find a name for the source of control edge, this
      // source must be outside of the body, and not an input. Raise an error.
      if (normalized.empty()) {
        return InvalidArgument(
            "The source of control edge ", edge->DebugString(),
            " is not in the body. Encountered while creating function '",
            fn_name, "'");
      }
      // Control inputs use the "^name" syntax in NodeDef.input.
      node_def->add_input(strings::StrCat("^", normalized));
    }

    // A function is stateful if any of its nodes are stateful.
    if (node->op_def().is_stateful()) {
      fdef->mutable_signature()->set_is_stateful(true);
    }

    // If this node has any attributes with placeholder value, add the
    // attribute to FunctionDef signature.
    for (const auto& iter : node->attrs()) {
      if (iter.second.placeholder().empty()) {
        continue;
      }

      // If we already added the attribute, skip it.
      string func_attr_name = iter.second.placeholder();
      if (func_attr_names.find(func_attr_name) != func_attr_names.end()) {
        continue;
      }

      // This node's attribute is a placeholder value, so it does not have type
      // information. We check node's OpDef for attribute type.
      string node_attr_name = iter.first;
      const OpDef::AttrDef* node_attr_def = nullptr;
      for (const auto& node_attr : node->op_def().attr()) {
        if (node_attr.name() == node_attr_name) {
          node_attr_def = &node_attr;
        }
      }
      if (!node_attr_def) {
        // Lite protos lack DebugString(); emit a shorter message there.
#ifdef TENSORFLOW_LITE_PROTOS
        return errors::Unimplemented(
            "Placeholder value is not supported for attributes not in OpDef. "
            "Attribute: ",
            node_attr_name);
#else
        return errors::Unimplemented(
            "Placeholder value is not supported for attributes not in OpDef. "
            "Attribute: ",
            node_attr_name, ", OpDef: ", node->op_def().DebugString());
#endif
      }
      OpDef::AttrDef* attr_def = fdef->mutable_signature()->add_attr();
      attr_def->set_name(func_attr_name);
      attr_def->set_type(node_attr_def->type());

      func_attr_names.insert(func_attr_name);
    }
  }
  return Status::OK();
}
296
// Graph to FunctionDef conversion. This code is closely modeled on the Python
// code in tensorflow/python/framework/function.py.
//
// - `fn_body`/`body_nodes`: the graph and the subset of its nodes that form
//   the function body.
// - `inputs`/`outputs`: tensors that become the signature's input/output
//   args.
// - `output_names`: optional explicit output-arg names; when non-empty it
//   must be parallel to `outputs` (DCHECK'd below).
// - `control_outputs`/`control_output_names`: nodes exported as control
//   rets; when `control_output_names` is non-empty it must be parallel to
//   `control_outputs`.
// - If `append_hash_to_fn_name` is true, a hash of the otherwise-complete
//   fdef is appended to the function name, making the name a fingerprint of
//   the body.
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
                          bool append_hash_to_fn_name,
                          const std::vector<const Node*>& body_nodes,
                          const std::vector<OutputTensor>& inputs,
                          const std::vector<OutputTensor>& outputs,
                          const std::vector<string>& output_names,
                          const std::vector<const Node*>& control_outputs,
                          const std::vector<string>& control_output_names,
                          const char* description, FunctionDef* fdef) {
  if (!output_names.empty()) {
    DCHECK_EQ(output_names.size(), outputs.size());
  }

  if (description != nullptr) {
    fdef->mutable_signature()->set_description(description);
  }

  // Keep track of names we used and how we normalized them.
  NodeNameMapping node_names;

  // Mapping from original names of tensors (i.e. "<node_name>:<idx>") to the
  // name we used in the function:
  //  - For input tensors:
  //    {flat_tensor_name -> normalized_name_of_src_node}
  //    e.g. {In:3 -> in}
  //  - For tensors produced by nodes in function's body:
  //    {flat_tensor_name -> nested_tensor_name}
  //    e.g. {Add:3 -> add_0:z:1}
  std::unordered_map<string, string> tensor_renaming;

  // Fill outputs in function's signature.
  // We fill the outputs first to prevent output_names from colliding
  // with the input names we pick below. With this order, no names are used in
  // node_names yet, and output_names won't collide with anything (except
  // potentially with themselves).
  for (size_t i = 0; i < outputs.size(); ++i) {
    const Node* node = outputs[i].node;
    int idx = outputs[i].index;
    OpDef::ArgDef* argdef = fdef->mutable_signature()->add_output_arg();
    argdef->set_type(node->output_type(idx));
    if (!output_names.empty()) {
      // Caller-chosen names are used verbatim; duplicates are an error.
      TF_RETURN_IF_ERROR(node_names.UseOutputName(output_names[i]));
      argdef->set_name(output_names[i]);
    } else {
      // Otherwise derive the arg name from the producing node's name.
      argdef->set_name(node_names.GetOutputName(node->name()));
    }
  }

  // Fill inputs in function's signature.
  for (size_t i = 0; i < inputs.size(); ++i) {
    const Node* node = inputs[i].node;
    int idx = inputs[i].index;
    OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg();
    argdef->set_type(node->output_type(idx));
    const string& input_name = node_names.GetInputName(node->name());
    argdef->set_name(input_name);
    // Body nodes consuming this tensor will refer to it by the arg name.
    tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
  }

  // Populate tensor_renaming and node_names.
  // Generate the new output names for every node in the function.
  // The NodeDefs in FunctionDefs use a different naming scheme for
  // their inputs than the NodeDefs in a graph (see the comment for
  // FunctionDef.node_def in function.proto). We do the
  // graph tensor name -> function tensor name conversion for every
  // possible input (i.e. every node's outputs) and store the result
  // in tensor_renaming.
  for (const Node* node : body_nodes) {
    // Make sure node_name does not collide with an input or output name.
    const string& node_name = node_names.Uniquify(node->name());
    // For each output_arg in the op_def, the output_ranges
    // map will have [start, end] range of indices that this arg produces
    // among all the output tensors of this op.
    NameRangeMap output_ranges;
    TF_RETURN_IF_ERROR(
        NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
    for (const auto& output : output_ranges) {
      const StringPiece& output_name = output.first;
      int index_start = output.second.first;
      int index_end = output.second.second;
      for (int i = index_start; i < index_end; ++i) {
        const string& original_name = strings::StrCat(node->name(), ":", i);
        // Nested name format: "<node>:<output_arg>:<idx within arg>".
        const string& new_name =
            strings::StrCat(node_name, ":", output_name, ":", i - index_start);
        // Record the mapping if this tensor is not already mapped.
        // Tensor can be already mapped if it is used as an input.
        if (tensor_renaming.find(original_name) == tensor_renaming.end()) {
          tensor_renaming[original_name] = new_name;
        }
      }
    }
  }

  TF_RETURN_IF_ERROR(
      FillFunctionBody(fn_name, node_names, body_nodes, tensor_renaming, fdef));

  // Remap return values.
  for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
    const string& ret_name = fdef->signature().output_arg(r).name();
    // We convert this flat tensor name to the nested value
    // (e.g. `add:z:1`) that we stored in tensor_renaming.
    const string& return_value =
        strings::StrCat(outputs[r].node->name(), ":", outputs[r].index);
    const auto iter = tensor_renaming.find(return_value);
    if (iter == tensor_renaming.end()) {
      return InvalidArgument(
          "TF_Output ", return_value, " is neither in the function body ",
          "nor among function inputs. Encountered while creating function '",
          fn_name, "'");
    }
    (*fdef->mutable_ret())[ret_name] = iter->second;
  }

  if (append_hash_to_fn_name) {
    const uint64 hash = FunctionDefHash(*fdef);
    string encoded;
    TF_RETURN_IF_ERROR(Base64Encode(
        StringPiece(reinterpret_cast<const char*>(&hash), sizeof(hash)),
        &encoded));
    // Besides letters and digits our Base64 encoding uses '_' and '-'.
    // Dash is invalid in operation names and multiple underscores in random
    // places look strange. Since we never need to decode the hash back,
    // replace these chars with with 'a' and 'A'. Replacing with different
    // letters keeps more entropy.
    std::replace(encoded.begin(), encoded.end(), '-', 'a');
    std::replace(encoded.begin(), encoded.end(), '_', 'A');
    fdef->mutable_signature()->set_name(strings::StrCat(fn_name, "_", encoded));
  } else {
    fdef->mutable_signature()->set_name(fn_name);
  }

  // Validate and fill control outputs (control rets).
  if (!control_output_names.empty() &&
      (control_outputs.size() != control_output_names.size())) {
    return InvalidArgument(
        "Expected number of control outputs (", control_outputs.size(),
        ") and the number of control output names (",
        control_output_names.size(), ") to match but they do not.");
  }
  std::unordered_set<string> control_output_names_set;
  for (int i = 0; i < control_outputs.size(); ++i) {
    string signature_name;
    if (!control_output_names.empty()) {
      signature_name = control_output_names[i];
    } else {
      signature_name = control_outputs[i]->name();
    }
    if (!control_output_names_set.insert(signature_name).second) {
      return errors::InvalidArgument("Repeated control output name: ",
                                     signature_name);
    }
    fdef->mutable_signature()->add_control_output(signature_name);
    (*fdef->mutable_control_ret())[signature_name] = control_outputs[i]->name();
  }

  return Status::OK();
}
455
456 // Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and
457 // does various checks while doing so. `input_nodes` will contain the same
458 // information as input_tensors just in a different structure to make
459 // following processing easier. TODO(iga): Simplify this nested structure.
ProcessInputs(const TF_Graph * fn_body,const char * fn_name,int ninputs,const TF_Output * inputs,std::vector<OutputTensor> * input_tensors,std::unordered_map<const Node *,std::vector<int>> * input_nodes)460 Status ProcessInputs(
461 const TF_Graph* fn_body, const char* fn_name, int ninputs,
462 const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
463 std::unordered_map<const Node*, std::vector<int>>* input_nodes)
464 EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
465 input_tensors->reserve(ninputs);
466 for (int i = 0; i < ninputs; ++i) {
467 Node* node = &inputs[i].oper->node;
468 int idx = inputs[i].index;
469
470 TF_RETURN_WITH_CONTEXT_IF_ERROR(
471 fn_body->graph.IsValidOutputTensor(node, idx),
472 "Encountered while processing input ", i, " into function '", fn_name,
473 "'");
474 TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
475 "Encountered while processing input ", i,
476 " into function '", fn_name, "'");
477
478 input_tensors->emplace_back(node, idx);
479
480 const auto& iter = input_nodes->find(node);
481 if (iter == input_nodes->end()) {
482 input_nodes->insert({node, {idx}});
483 } else {
484 auto& indices = iter->second;
485 if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
486 return InvalidArgument("TF_Output ", node->name(), ":", idx,
487 " appears more than once in the input list");
488 }
489 indices.push_back(idx);
490 }
491 }
492 return Status::OK();
493 }
494
495 // Converts `noutputs` and `outputs` into `outputs_tensors` and does various
496 // checks while doing so.
ProcessOutputs(const TF_Graph * fn_body,const char * fn_name,int noutputs,const TF_Output * outputs,std::vector<OutputTensor> * output_tensors)497 Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
498 int noutputs, const TF_Output* outputs,
499 std::vector<OutputTensor>* output_tensors)
500 EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
501 output_tensors->reserve(noutputs);
502 for (int i = 0; i < noutputs; ++i) {
503 Node* node = &outputs[i].oper->node;
504 int idx = outputs[i].index;
505 TF_RETURN_WITH_CONTEXT_IF_ERROR(
506 fn_body->graph.IsValidOutputTensor(node, idx),
507 "Encountered while processing output ", i, " from function '", fn_name,
508 "'");
509 TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
510 "Encountered while creating function '",
511 fn_name, "'");
512 output_tensors->emplace_back(node, idx);
513 }
514 return Status::OK();
515 }
516
517 // Populates `body_nodes` with the nodes that will become function's body.
518 // Performs various checks.
ComputeBodyNodes(const TF_Graph * fn_body,const char * fn_name,int num_opers,const TF_Operation * const * opers,const std::unordered_map<const Node *,std::vector<int>> & input_nodes,std::vector<const Node * > * body_nodes)519 Status ComputeBodyNodes(
520 const TF_Graph* fn_body, const char* fn_name, int num_opers,
521 const TF_Operation* const* opers,
522 const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
523 std::vector<const Node*>* body_nodes)
524 EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
525 if (num_opers == -1) {
526 for (const Node* node : fn_body->graph.op_nodes()) {
527 const auto& iter = input_nodes.find(node);
528 if (iter == input_nodes.end()) {
529 // This node is not referenced in inputs. Add it to the body.
530 body_nodes->push_back(node);
531 } else {
532 // This node is referenced in inputs. Currently, we place an
533 // artificial restriction and require that when num_opers=-1, such
534 // nodes must have a single output.
535 if (node->num_outputs() != 1) {
536 return InvalidArgument(
537 "When `num_opers` is set to -1, nodes referenced in `inputs` "
538 "must have a single output. Node ",
539 node->name(), " has ", node->num_outputs(),
540 " outputs. Encountered while creating function '", fn_name, "'");
541 }
542 }
543 }
544 } else {
545 body_nodes->reserve(num_opers);
546 for (int i = 0; i < num_opers; ++i) {
547 const Node* node = &opers[i]->node;
548 body_nodes->push_back(node);
549 }
550 }
551 return Status::OK();
552 }
553
554 } // namespace
555 } // namespace tensorflow
556
557 using tensorflow::Node;
558 using tensorflow::string;
559
TF_GraphToFunctionWithControlOutputs(const TF_Graph * fn_body,const char * fn_name,unsigned char append_hash_to_fn_name,int num_opers,const TF_Operation * const * opers,int ninputs,const TF_Output * inputs,int noutputs,const TF_Output * outputs,const char * const * output_names,int ncontrol_outputs,const TF_Operation * const * control_outputs,const char * const * control_output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * status)560 TF_Function* TF_GraphToFunctionWithControlOutputs(
561 const TF_Graph* fn_body, const char* fn_name,
562 unsigned char append_hash_to_fn_name, int num_opers,
563 const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
564 int noutputs, const TF_Output* outputs, const char* const* output_names,
565 int ncontrol_outputs, const TF_Operation* const* control_outputs,
566 const char* const* control_output_names, const TF_FunctionOptions* opts,
567 const char* description, TF_Status* status) {
568 tensorflow::mutex_lock l(*const_cast<tensorflow::mutex*>(&fn_body->mu));
569
570 // Process inputs.
571 std::vector<tensorflow::OutputTensor> input_tensors;
572 std::unordered_map<const Node*, std::vector<int>> input_nodes;
573 status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
574 &input_tensors, &input_nodes);
575 if (TF_GetCode(status) != TF_OK) return nullptr;
576
577 // Process outputs.
578 std::vector<tensorflow::OutputTensor> output_tensors;
579 status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
580 outputs, &output_tensors);
581 if (TF_GetCode(status) != TF_OK) return nullptr;
582
583 // Process output names.
584 std::vector<string> output_names_vec;
585 if (output_names) {
586 output_names_vec.reserve(noutputs);
587 for (int i = 0; i < noutputs; ++i) {
588 output_names_vec.push_back(string(output_names[i]));
589 }
590 }
591
592 // Process control output names.
593 std::vector<string> control_output_names_vec;
594 if (control_output_names) {
595 control_output_names_vec.reserve(ncontrol_outputs);
596 for (int i = 0; i < ncontrol_outputs; ++i) {
597 control_output_names_vec.push_back(string(output_names[i]));
598 }
599 }
600
601 // Compute body nodes.
602 std::vector<const Node*> body_nodes;
603 status->status = tensorflow::ComputeBodyNodes(
604 fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
605 if (TF_GetCode(status) != TF_OK) return nullptr;
606
607 // Compute body nodes.
608 std::vector<const Node*> control_output_nodes;
609 for (int i = 0; i < ncontrol_outputs; ++i) {
610 control_output_nodes.push_back(&control_outputs[i]->node);
611 }
612
613 // Do the actual function creation.
614 TF_Function* tf_function = new TF_Function();
615 DCHECK(append_hash_to_fn_name <= 1);
616 status->status = tensorflow::GraphToFunctionDef(
617 fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes,
618 input_tensors, output_tensors, output_names_vec, control_output_nodes,
619 control_output_names_vec, description, &tf_function->fdef);
620 if (TF_GetCode(status) != TF_OK) {
621 TF_DeleteFunction(tf_function);
622 return nullptr;
623 }
624 return tf_function;
625 }
626
// Convenience wrapper around TF_GraphToFunctionWithControlOutputs for the
// common case of a function with no control outputs.
TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
                                unsigned char append_hash_to_fn_name,
                                int num_opers, const TF_Operation* const* opers,
                                int ninputs, const TF_Output* inputs,
                                int noutputs, const TF_Output* outputs,
                                const char* const* output_names,
                                const TF_FunctionOptions* opts,
                                const char* description, TF_Status* status) {
  return TF_GraphToFunctionWithControlOutputs(
      fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs,
      inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts,
      description, status);
}
640
// Returns the function's name. The returned pointer points into memory owned
// by `func` and is invalidated when `func` is deleted.
const char* TF_FunctionName(TF_Function* func) {
  return func->fdef.signature().name().c_str();
}
644
// Copies `func` (and, if non-null, its gradient function `grad`) into graph
// `g`'s function library, also registering the func->grad gradient mapping
// when `grad` is given.
void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func,
                          const TF_Function* grad, TF_Status* status) {
  if (func == nullptr) {
    status->status = InvalidArgument(
        "'func' argument to TF_GraphCopyFunction cannot be null");
    return;
  }

  // TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph
  // to avoid the extra copy here.
  tensorflow::FunctionDefLibrary fdef_lib;
  *fdef_lib.add_function() = func->fdef;
  if (grad) {
    *fdef_lib.add_function() = grad->fdef;
    tensorflow::GradientDef* gdef = fdef_lib.add_gradient();
    gdef->set_function_name(func->fdef.signature().name());
    gdef->set_gradient_func(grad->fdef.signature().name());
  }

  tensorflow::mutex_lock l(g->mu);
  status->status = g->graph.AddFunctionLibrary(fdef_lib);
}
667
// Returns the number of functions registered in `g`'s function library.
int TF_GraphNumFunctions(TF_Graph* g) {
  tensorflow::mutex_lock l(g->mu);
  return g->graph.flib_def().num_functions();
}
672
TF_GraphGetFunctions(TF_Graph * g,TF_Function ** funcs,int max_func,TF_Status * status)673 int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func,
674 TF_Status* status) {
675 tensorflow::FunctionDefLibrary lib;
676 {
677 tensorflow::mutex_lock l(g->mu);
678 lib = g->graph.flib_def().ToProto();
679 }
680 const auto len = std::min(max_func, static_cast<int>(lib.function_size()));
681 for (int i = 0; i < len; ++i) {
682 TF_Function* func = new TF_Function();
683 func->fdef = lib.function(i);
684 funcs[i] = func;
685 }
686 status->status = tensorflow::Status::OK();
687 return len;
688 }
689
// Serializes `func`'s FunctionDef proto into `output_func_def`.
void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
                              TF_Status* status) {
  status->status = MessageToBuffer(func->fdef, output_func_def);
}
694
TF_FunctionImportFunctionDef(const void * proto,size_t proto_len,TF_Status * status)695 TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len,
696 TF_Status* status) {
697 TF_Function* func = new TF_Function();
698 if (!func->fdef.ParseFromArray(proto, proto_len)) {
699 status->status = InvalidArgument(
700 "Invalid FunctionDef given to TF_FunctionImportFunctionDef");
701 TF_DeleteFunction(func);
702 return nullptr;
703 }
704 status->status = tensorflow::Status::OK();
705 return func;
706 }
707
TF_FunctionSetAttrValueProto(TF_Function * func,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)708 void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name,
709 const void* proto, size_t proto_len,
710 TF_Status* status) {
711 tensorflow::AttrValue attr_value;
712 if (!attr_value.ParseFromArray(proto, proto_len)) {
713 status->status = InvalidArgument(
714 "Unparseable AttrValue proto passed to "
715 "TF_FunctionSetAttrValueProto");
716 return;
717 }
718 (*func->fdef.mutable_attr())[string(attr_name)] = attr_value;
719 status->status = tensorflow::Status::OK();
720 }
721
TF_FunctionGetAttrValueProto(TF_Function * func,const char * attr_name,TF_Buffer * output_attr_value,TF_Status * status)722 void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name,
723 TF_Buffer* output_attr_value,
724 TF_Status* status) {
725 const auto& it = func->fdef.attr().find(attr_name);
726 if (it == func->fdef.attr().end()) {
727 status->status =
728 InvalidArgument("Function '", func->fdef.signature().name(),
729 "' has no attr named '", attr_name, "'.");
730 return;
731 }
732 status->status = MessageToBuffer(it->second, output_attr_value);
733 }
734
// Frees `func`. Calling with nullptr is a no-op (delete on null is safe).
void TF_DeleteFunction(TF_Function* func) { delete func; }
736