• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api_internal.h"
17 
18 #include <algorithm>
19 #include <unordered_map>
20 #include <unordered_set>
21 
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/function.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/node_def_util.h"
26 #include "tensorflow/core/framework/tensor.pb.h"  // NOLINT
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/graph/graph.h"
29 #include "tensorflow/core/lib/strings/base64.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 
32 using tensorflow::errors::InvalidArgument;
33 
34 namespace tensorflow {
35 namespace {
36 
37 // Class that maintains a one-to-one original node name -> new node name
38 // mapping. We normalize the names used as input and output arguments to match
39 // regexp "[a-z][a-z0-9_]*" specified in definition of ArgDef.name.
40 // Once we rename them, we risk creating a name collision with the other
41 // node names, so if necessary we add a suffix to make
42 // names unique. If we have an input named "A" and a node in the function
43 // body named "a", they will be renamed to "a" and "a_0".
44 class NodeNameMapping {
45  public:
46   NodeNameMapping() = default;
47 
48   // Normalize the input name and make it unique. This is the same as the
49   // function for output, expect that it adds a name mapping for the name.
50   string GetInputName(const string& name);
51 
52   // Normalize the output name and make it unique.
53   string GetOutputName(const string& name);
54 
55   // Make the node name unique.
56   string Uniquify(const string& name);
57 
58   // Records name as a used name. If this name is already used,
59   // returns an error status.
60   Status UseOutputName(const string& name);
61 
62   // Look up how a node name was previously normalized/uniquified.
63   // Returns empty if name was never seen.
64   string Lookup(const string& name) const;
65 
66  private:
67   string UniquifyHelper(const string& name) const;
68   static string Normalize(string name);
69 
70   // The normalized/uniquified names already used as
71   // input names (in signature), output names (in signature), and node names
72   // (in node_def).
73   // This is a superset of values in name_mapping_.
74   std::unordered_set<string> used_names_;
75   // Mapping from original node name from the graph to the normalized
76   // and uniquified version of it.
77   std::unordered_map<string, string> name_mapping_;
78 };
79 
Normalize(string name)80 string NodeNameMapping::Normalize(string name) {
81   // Convert letters to lowercase and non-alphanumeric characters to '_'.
82   if (name.empty()) return "unknown";
83   const int n = name.size();
84   for (int i = 0; i < n; ++i) {
85     char c = name[i];
86     if (isalnum(c)) {
87       if (isupper(c)) {
88         name[i] = tolower(c);
89       }
90     } else {
91       name[i] = '_';
92     }
93   }
94 
95   // Find the first letter and start with it.
96   int i = 0;
97   for (; i < n; ++i) {
98     if (isalpha(name[i])) break;
99   }
100 
101   // Return "unknown" if none of the name's chars were letters.
102   return i == n ? "unknown" : name.substr(i);
103 }
104 
// Returns `name` itself when it has not been used, otherwise the first
// candidate of the form "name_0", "name_1", ... that is still free.
// The chosen name is not recorded; callers insert it into used_names_.
string NodeNameMapping::UniquifyHelper(const string& name) const {
  if (used_names_.count(name) == 0) return name;

  // used_names_ is finite, so this probe always terminates.
  string candidate;
  int suffix = 0;
  do {
    candidate = strings::StrCat(name, "_", suffix);
    ++suffix;
  } while (used_names_.count(candidate) != 0);
  return candidate;
}
114 
// Normalizes and uniquifies an input argument name exactly like an output
// name, and additionally remembers original -> new name so that Lookup()
// can resolve references (e.g. control edges) to the input node.
string NodeNameMapping::GetInputName(const string& name) {
  string unique_name = GetOutputName(name);
  name_mapping_[name] = unique_name;
  return unique_name;
}
120 
// Normalizes `name`, makes it unique, and reserves the result. No entry is
// added to name_mapping_ because an output argument name does not name a
// node.
string NodeNameMapping::GetOutputName(const string& name) {
  string output_name = UniquifyHelper(Normalize(name));
  used_names_.insert(output_name);
  return output_name;
}
128 
// Makes a node name unique without normalizing it, reserves the result, and
// records the original -> unique mapping for Lookup().
string NodeNameMapping::Uniquify(const string& name) {
  string unique = UniquifyHelper(name);
  used_names_.insert(unique);
  name_mapping_[name] = unique;
  return unique;
}
135 
UseOutputName(const string & name)136 Status NodeNameMapping::UseOutputName(const string& name) {
137   const auto& iter = used_names_.find(name);
138   if (iter != used_names_.end()) {
139     return InvalidArgument("Cannot have duplicate output names. Name '", name,
140                            "' appears more than once in 'output_names' array.");
141   }
142   used_names_.insert(iter, name);
143   return Status::OK();
144 }
145 
// Returns the recorded new name for `name`, or an empty string when no
// mapping exists.
string NodeNameMapping::Lookup(const string& name) const {
  auto it = name_mapping_.find(name);
  return it != name_mapping_.end() ? it->second : string();
}
151 
ValidateNonRefOutput(const Node * node,int idx)152 Status ValidateNonRefOutput(const Node* node, int idx) {
153   const DataType& dt = node->output_type(idx);
154   return IsRefType(dt)
155              ? InvalidArgument("Output ", idx, " of node '", node->name(),
156                                "' has a reference type ", DataTypeString(dt))
157              : Status::OK();
158 }
159 
FillFunctionBody(const string & fn_name,const NodeNameMapping & node_names,const std::vector<const Node * > & body_nodes,const std::unordered_map<string,string> & tensor_renaming,FunctionDef * fdef)160 Status FillFunctionBody(
161     const string& fn_name, const NodeNameMapping& node_names,
162     const std::vector<const Node*>& body_nodes,
163     const std::unordered_map<string, string>& tensor_renaming,
164     FunctionDef* fdef) {
165   std::unordered_set<string> func_attr_names;
166   for (const auto& func_attr : fdef->signature().attr()) {
167     func_attr_names.insert(func_attr.name());
168   }
169 
170   std::vector<const Edge*> in_edges;
171   std::vector<const Edge*> control_edges;
172   for (const Node* node : body_nodes) {
173     NodeDef* node_def = fdef->add_node_def();
174     // First, copy the node_def as is. We will patch it next.
175     *node_def = node->def();
176     if (!node->assigned_device_name().empty()) {
177       node_def->set_device(node->assigned_device_name());
178     }
179     node_def->set_name(node_names.Lookup(node->name()));
180 
181     // Input names must be set based on nested names in tensor_renaming.
182     // Clear the flat input names we got from the original node_def
183     // from the graph.
184     node_def->clear_input();
185 
186     // Collect regular and control inputs. Regular inputs are indexed
187     // by the index at which they come into the `node`. Control inputs
188     // don't follow any order.
189     in_edges.clear();
190     in_edges.resize(node->num_inputs(), nullptr);
191     control_edges.clear();
192     for (const Edge* edge : node->in_edges()) {
193       if (edge->src()->IsSource()) continue;
194       if (edge->IsControlEdge()) {
195         control_edges.push_back(edge);
196       } else {
197         in_edges[edge->dst_input()] = edge;
198       }
199     }
200 
201     // Add regular inputs.
202     for (size_t i = 0; i < in_edges.size(); ++i) {
203       const Edge* edge = in_edges[i];
204       string original_input_name;
205       if (edge == nullptr) {
206         // A backedge might not appear as a regular Edge, but be only present
207         // in the node_def. Such edges are referred to as requested_inputs().
208         if (i >= node->requested_inputs().size()) {
209           return InvalidArgument(
210               "Graph to be converted to function appears to be malformed. ",
211               "Node ", node->name(), " is missing input edge ", i);
212         }
213         original_input_name =
214             ParseTensorName(node->requested_inputs()[i]).ToString();
215       } else {
216         original_input_name =
217             strings::StrCat(edge->src()->name(), ":", edge->src_output());
218       }
219 
220       const auto iter = tensor_renaming.find(original_input_name);
221       if (iter == tensor_renaming.end()) {
222         return InvalidArgument(
223             "Input ", i, ", '", original_input_name, "', of node '",
224             node->name(), "' in function '", fn_name,
225             "' is not available. You might need to include it in inputs "
226             "or include its source node in the body");
227       }
228       node_def->add_input(iter->second);
229     }
230 
231     // Add control inputs.
232     for (const Edge* edge : control_edges) {
233       // Add this control input only if the src node is in the body or a part of
234       // the inputs.
235       const string normalized = node_names.Lookup(edge->src()->name());
236       // If we did not find a name for the source of control edge, this
237       // source must be outside of the body, and not an input. Raise an error.
238       if (normalized.empty()) {
239         return InvalidArgument(
240             "The source of control edge ", edge->DebugString(),
241             " is not in the body. Encountered while creating function '",
242             fn_name, "'");
243       }
244       node_def->add_input(strings::StrCat("^", normalized));
245     }
246 
247     // A function is stateful if any of its nodes are stateful.
248     if (node->op_def().is_stateful()) {
249       fdef->mutable_signature()->set_is_stateful(true);
250     }
251 
252     // If this node has any attributes with placeholder value, add the
253     // attribute to FunctionDef signature.
254     for (const auto& iter : node->attrs()) {
255       if (iter.second.placeholder().empty()) {
256         continue;
257       }
258 
259       // If we already added the attribute, skip it.
260       string func_attr_name = iter.second.placeholder();
261       if (func_attr_names.find(func_attr_name) != func_attr_names.end()) {
262         continue;
263       }
264 
265       // This node's attribute is a placeholder value, so it does not have type
266       // information. We check node's OpDef for attribute type.
267       string node_attr_name = iter.first;
268       const OpDef::AttrDef* node_attr_def = nullptr;
269       for (const auto& node_attr : node->op_def().attr()) {
270         if (node_attr.name() == node_attr_name) {
271           node_attr_def = &node_attr;
272         }
273       }
274       if (!node_attr_def) {
275 #ifdef TENSORFLOW_LITE_PROTOS
276         return errors::Unimplemented(
277             "Placeholder value is not supported for attributes not in OpDef. "
278             "Attribute: ",
279             node_attr_name);
280 #else
281         return errors::Unimplemented(
282             "Placeholder value is not supported for attributes not in OpDef. "
283             "Attribute: ",
284             node_attr_name, ", OpDef: ", node->op_def().DebugString());
285 #endif
286       }
287       OpDef::AttrDef* attr_def = fdef->mutable_signature()->add_attr();
288       attr_def->set_name(func_attr_name);
289       attr_def->set_type(node_attr_def->type());
290 
291       func_attr_names.insert(func_attr_name);
292     }
293   }
294   return Status::OK();
295 }
296 
297 // Graph to FunctionDef conversion. This code is closely modeled on the Python
298 // code in tensorflow/python/framework/function.py.
GraphToFunctionDef(const Graph & fn_body,const string & fn_name,bool append_hash_to_fn_name,const std::vector<const Node * > & body_nodes,const std::vector<OutputTensor> & inputs,const std::vector<OutputTensor> & outputs,const std::vector<string> & output_names,const std::vector<const Node * > & control_outputs,const std::vector<string> & control_output_names,const char * description,FunctionDef * fdef)299 Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
300                           bool append_hash_to_fn_name,
301                           const std::vector<const Node*>& body_nodes,
302                           const std::vector<OutputTensor>& inputs,
303                           const std::vector<OutputTensor>& outputs,
304                           const std::vector<string>& output_names,
305                           const std::vector<const Node*>& control_outputs,
306                           const std::vector<string>& control_output_names,
307                           const char* description, FunctionDef* fdef) {
308   if (!output_names.empty()) {
309     DCHECK_EQ(output_names.size(), outputs.size());
310   }
311 
312   if (description != nullptr) {
313     fdef->mutable_signature()->set_description(description);
314   }
315 
316   // Keep track of names we used and how we normalized them.
317   NodeNameMapping node_names;
318 
319   // Mapping from original names of tensors (i.e. "<node_name>:<idx>") to the
320   // name we used in the function:
321   //  - For input tensors:
322   //    {flat_tensor_name -> normalized_name_of_src_node}
323   //    e.g. {In:3 -> in}
324   //  - For tensors produced by nodes in function's body:
325   //    {flat_tensor_name -> nested_tensor_name}
326   //    e.g. {Add:3 -> add_0:z:1}
327   std::unordered_map<string, string> tensor_renaming;
328 
329   // Fill outputs in function's signature.
330   // We fill the outputs first to prevent output_names from colliding
331   // with the input names we pick below. With this order, no names are used in
332   // node_names yet, and output_names won't collide with anything (except
333   // potentially with themselves).
334   for (size_t i = 0; i < outputs.size(); ++i) {
335     const Node* node = outputs[i].node;
336     int idx = outputs[i].index;
337     OpDef::ArgDef* argdef = fdef->mutable_signature()->add_output_arg();
338     argdef->set_type(node->output_type(idx));
339     if (!output_names.empty()) {
340       TF_RETURN_IF_ERROR(node_names.UseOutputName(output_names[i]));
341       argdef->set_name(output_names[i]);
342     } else {
343       argdef->set_name(node_names.GetOutputName(node->name()));
344     }
345   }
346 
347   // Fill inputs in function's signature.
348   for (size_t i = 0; i < inputs.size(); ++i) {
349     const Node* node = inputs[i].node;
350     int idx = inputs[i].index;
351     OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg();
352     argdef->set_type(node->output_type(idx));
353     const string& input_name = node_names.GetInputName(node->name());
354     argdef->set_name(input_name);
355     tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
356   }
357 
358   // Populate tensor_renaming and node_names.
359   // Generate the new output names for every node in the function.
360   // The NodeDefs in FunctionDefs use a different naming scheme for
361   // their inputs than the NodeDefs in a graph (see the comment for
362   // FunctionDef.node_def in function.proto). We do the
363   // graph tensor name -> function tensor name conversion for every
364   // possible input (i.e. every node's outputs) and store the result
365   // in tensor_renaming.
366   for (const Node* node : body_nodes) {
367     // Make sure node_name does not collide with an input or output name.
368     const string& node_name = node_names.Uniquify(node->name());
369     // For each output_arg in the op_def, the output_ranges
370     // map will have [start, end] range of indices that this arg produces
371     // among all the output tensors of this op.
372     NameRangeMap output_ranges;
373     TF_RETURN_IF_ERROR(
374         NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
375     for (const auto& output : output_ranges) {
376       const StringPiece& output_name = output.first;
377       int index_start = output.second.first;
378       int index_end = output.second.second;
379       for (int i = index_start; i < index_end; ++i) {
380         const string& original_name = strings::StrCat(node->name(), ":", i);
381         const string& new_name =
382             strings::StrCat(node_name, ":", output_name, ":", i - index_start);
383         // Record the mapping if this tensor is not already mapped.
384         // Tensor can be already mapped if it is used as an input.
385         if (tensor_renaming.find(original_name) == tensor_renaming.end()) {
386           tensor_renaming[original_name] = new_name;
387         }
388       }
389     }
390   }
391 
392   TF_RETURN_IF_ERROR(
393       FillFunctionBody(fn_name, node_names, body_nodes, tensor_renaming, fdef));
394 
395   // Remap return values.
396   for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
397     const string& ret_name = fdef->signature().output_arg(r).name();
398     // We convert this flat tensor name to the nested value
399     // (e.g. `add:z:1`) that we stored in tensor_renaming.
400     const string& return_value =
401         strings::StrCat(outputs[r].node->name(), ":", outputs[r].index);
402     const auto iter = tensor_renaming.find(return_value);
403     if (iter == tensor_renaming.end()) {
404       return InvalidArgument(
405           "TF_Output ", return_value, " is neither in the function body ",
406           "nor among function inputs. Encountered while creating function '",
407           fn_name, "'");
408     }
409     (*fdef->mutable_ret())[ret_name] = iter->second;
410   }
411 
412   if (append_hash_to_fn_name) {
413     const uint64 hash = FunctionDefHash(*fdef);
414     string encoded;
415     TF_RETURN_IF_ERROR(Base64Encode(
416         StringPiece(reinterpret_cast<const char*>(&hash), sizeof(hash)),
417         &encoded));
418     // Besides letters and digits our Base64 encoding uses '_' and '-'.
419     // Dash is invalid in operation names and multiple underscores in random
420     // places look strange. Since we never need to decode the hash back,
421     // replace these chars with with 'a' and 'A'. Replacing with different
422     // letters keeps more entropy.
423     std::replace(encoded.begin(), encoded.end(), '-', 'a');
424     std::replace(encoded.begin(), encoded.end(), '_', 'A');
425     fdef->mutable_signature()->set_name(strings::StrCat(fn_name, "_", encoded));
426   } else {
427     fdef->mutable_signature()->set_name(fn_name);
428   }
429 
430   if (!control_output_names.empty() &&
431       (control_outputs.size() != control_output_names.size())) {
432     return InvalidArgument(
433         "Expected number of control outputs (", control_outputs.size(),
434         ") and the number of control output names (",
435         control_output_names.size(), ") to match but they do not.");
436   }
437   std::unordered_set<string> control_output_names_set;
438   for (int i = 0; i < control_outputs.size(); ++i) {
439     string signature_name;
440     if (!control_output_names.empty()) {
441       signature_name = control_output_names[i];
442     } else {
443       signature_name = control_outputs[i]->name();
444     }
445     if (!control_output_names_set.insert(signature_name).second) {
446       return errors::InvalidArgument("Repeated control output name: ",
447                                      signature_name);
448     }
449     fdef->mutable_signature()->add_control_output(signature_name);
450     (*fdef->mutable_control_ret())[signature_name] = control_outputs[i]->name();
451   }
452 
453   return Status::OK();
454 }
455 
456 // Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and
457 // does various checks while doing so. `input_nodes` will contain the same
458 // information as input_tensors just in a different structure to make
459 // following processing easier. TODO(iga): Simplify this nested structure.
ProcessInputs(const TF_Graph * fn_body,const char * fn_name,int ninputs,const TF_Output * inputs,std::vector<OutputTensor> * input_tensors,std::unordered_map<const Node *,std::vector<int>> * input_nodes)460 Status ProcessInputs(
461     const TF_Graph* fn_body, const char* fn_name, int ninputs,
462     const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
463     std::unordered_map<const Node*, std::vector<int>>* input_nodes)
464     EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
465   input_tensors->reserve(ninputs);
466   for (int i = 0; i < ninputs; ++i) {
467     Node* node = &inputs[i].oper->node;
468     int idx = inputs[i].index;
469 
470     TF_RETURN_WITH_CONTEXT_IF_ERROR(
471         fn_body->graph.IsValidOutputTensor(node, idx),
472         "Encountered while processing input ", i, " into function '", fn_name,
473         "'");
474     TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
475                                     "Encountered while processing input ", i,
476                                     " into function '", fn_name, "'");
477 
478     input_tensors->emplace_back(node, idx);
479 
480     const auto& iter = input_nodes->find(node);
481     if (iter == input_nodes->end()) {
482       input_nodes->insert({node, {idx}});
483     } else {
484       auto& indices = iter->second;
485       if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
486         return InvalidArgument("TF_Output ", node->name(), ":", idx,
487                                " appears more than once in the input list");
488       }
489       indices.push_back(idx);
490     }
491   }
492   return Status::OK();
493 }
494 
495 // Converts `noutputs` and `outputs` into `outputs_tensors` and does various
496 // checks while doing so.
ProcessOutputs(const TF_Graph * fn_body,const char * fn_name,int noutputs,const TF_Output * outputs,std::vector<OutputTensor> * output_tensors)497 Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
498                       int noutputs, const TF_Output* outputs,
499                       std::vector<OutputTensor>* output_tensors)
500     EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
501   output_tensors->reserve(noutputs);
502   for (int i = 0; i < noutputs; ++i) {
503     Node* node = &outputs[i].oper->node;
504     int idx = outputs[i].index;
505     TF_RETURN_WITH_CONTEXT_IF_ERROR(
506         fn_body->graph.IsValidOutputTensor(node, idx),
507         "Encountered while processing output ", i, " from function '", fn_name,
508         "'");
509     TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
510                                     "Encountered while creating function '",
511                                     fn_name, "'");
512     output_tensors->emplace_back(node, idx);
513   }
514   return Status::OK();
515 }
516 
517 // Populates `body_nodes` with the nodes that will become function's body.
518 // Performs various checks.
ComputeBodyNodes(const TF_Graph * fn_body,const char * fn_name,int num_opers,const TF_Operation * const * opers,const std::unordered_map<const Node *,std::vector<int>> & input_nodes,std::vector<const Node * > * body_nodes)519 Status ComputeBodyNodes(
520     const TF_Graph* fn_body, const char* fn_name, int num_opers,
521     const TF_Operation* const* opers,
522     const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
523     std::vector<const Node*>* body_nodes)
524     EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
525   if (num_opers == -1) {
526     for (const Node* node : fn_body->graph.op_nodes()) {
527       const auto& iter = input_nodes.find(node);
528       if (iter == input_nodes.end()) {
529         // This node is not referenced in inputs. Add it to the body.
530         body_nodes->push_back(node);
531       } else {
532         // This node is referenced in inputs. Currently, we place an
533         // artificial restriction and require that when num_opers=-1, such
534         // nodes must have a single output.
535         if (node->num_outputs() != 1) {
536           return InvalidArgument(
537               "When `num_opers` is set to -1, nodes referenced in `inputs` "
538               "must have a single output. Node ",
539               node->name(), " has ", node->num_outputs(),
540               " outputs. Encountered while creating function '", fn_name, "'");
541         }
542       }
543     }
544   } else {
545     body_nodes->reserve(num_opers);
546     for (int i = 0; i < num_opers; ++i) {
547       const Node* node = &opers[i]->node;
548       body_nodes->push_back(node);
549     }
550   }
551   return Status::OK();
552 }
553 
554 }  // namespace
555 }  // namespace tensorflow
556 
557 using tensorflow::Node;
558 using tensorflow::string;
559 
TF_GraphToFunctionWithControlOutputs(const TF_Graph * fn_body,const char * fn_name,unsigned char append_hash_to_fn_name,int num_opers,const TF_Operation * const * opers,int ninputs,const TF_Output * inputs,int noutputs,const TF_Output * outputs,const char * const * output_names,int ncontrol_outputs,const TF_Operation * const * control_outputs,const char * const * control_output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * status)560 TF_Function* TF_GraphToFunctionWithControlOutputs(
561     const TF_Graph* fn_body, const char* fn_name,
562     unsigned char append_hash_to_fn_name, int num_opers,
563     const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
564     int noutputs, const TF_Output* outputs, const char* const* output_names,
565     int ncontrol_outputs, const TF_Operation* const* control_outputs,
566     const char* const* control_output_names, const TF_FunctionOptions* opts,
567     const char* description, TF_Status* status) {
568   tensorflow::mutex_lock l(*const_cast<tensorflow::mutex*>(&fn_body->mu));
569 
570   // Process inputs.
571   std::vector<tensorflow::OutputTensor> input_tensors;
572   std::unordered_map<const Node*, std::vector<int>> input_nodes;
573   status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
574                                              &input_tensors, &input_nodes);
575   if (TF_GetCode(status) != TF_OK) return nullptr;
576 
577   // Process outputs.
578   std::vector<tensorflow::OutputTensor> output_tensors;
579   status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
580                                               outputs, &output_tensors);
581   if (TF_GetCode(status) != TF_OK) return nullptr;
582 
583   // Process output names.
584   std::vector<string> output_names_vec;
585   if (output_names) {
586     output_names_vec.reserve(noutputs);
587     for (int i = 0; i < noutputs; ++i) {
588       output_names_vec.push_back(string(output_names[i]));
589     }
590   }
591 
592   // Process control output names.
593   std::vector<string> control_output_names_vec;
594   if (control_output_names) {
595     control_output_names_vec.reserve(ncontrol_outputs);
596     for (int i = 0; i < ncontrol_outputs; ++i) {
597       control_output_names_vec.push_back(string(output_names[i]));
598     }
599   }
600 
601   // Compute body nodes.
602   std::vector<const Node*> body_nodes;
603   status->status = tensorflow::ComputeBodyNodes(
604       fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
605   if (TF_GetCode(status) != TF_OK) return nullptr;
606 
607   // Compute body nodes.
608   std::vector<const Node*> control_output_nodes;
609   for (int i = 0; i < ncontrol_outputs; ++i) {
610     control_output_nodes.push_back(&control_outputs[i]->node);
611   }
612 
613   // Do the actual function creation.
614   TF_Function* tf_function = new TF_Function();
615   DCHECK(append_hash_to_fn_name <= 1);
616   status->status = tensorflow::GraphToFunctionDef(
617       fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes,
618       input_tensors, output_tensors, output_names_vec, control_output_nodes,
619       control_output_names_vec, description, &tf_function->fdef);
620   if (TF_GetCode(status) != TF_OK) {
621     TF_DeleteFunction(tf_function);
622     return nullptr;
623   }
624   return tf_function;
625 }
626 
TF_GraphToFunction(const TF_Graph * fn_body,const char * fn_name,unsigned char append_hash_to_fn_name,int num_opers,const TF_Operation * const * opers,int ninputs,const TF_Output * inputs,int noutputs,const TF_Output * outputs,const char * const * output_names,const TF_FunctionOptions * opts,const char * description,TF_Status * status)627 TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
628                                 unsigned char append_hash_to_fn_name,
629                                 int num_opers, const TF_Operation* const* opers,
630                                 int ninputs, const TF_Output* inputs,
631                                 int noutputs, const TF_Output* outputs,
632                                 const char* const* output_names,
633                                 const TF_FunctionOptions* opts,
634                                 const char* description, TF_Status* status) {
635   return TF_GraphToFunctionWithControlOutputs(
636       fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs,
637       inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts,
638       description, status);
639 }
640 
TF_FunctionName(TF_Function * func)641 const char* TF_FunctionName(TF_Function* func) {
642   return func->fdef.signature().name().c_str();
643 }
644 
TF_GraphCopyFunction(TF_Graph * g,const TF_Function * func,const TF_Function * grad,TF_Status * status)645 void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func,
646                           const TF_Function* grad, TF_Status* status) {
647   if (func == nullptr) {
648     status->status = InvalidArgument(
649         "'func' argument to TF_GraphCopyFunction cannot be null");
650     return;
651   }
652 
653   // TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph
654   // to avoid the extra copy here.
655   tensorflow::FunctionDefLibrary fdef_lib;
656   *fdef_lib.add_function() = func->fdef;
657   if (grad) {
658     *fdef_lib.add_function() = grad->fdef;
659     tensorflow::GradientDef* gdef = fdef_lib.add_gradient();
660     gdef->set_function_name(func->fdef.signature().name());
661     gdef->set_gradient_func(grad->fdef.signature().name());
662   }
663 
664   tensorflow::mutex_lock l(g->mu);
665   status->status = g->graph.AddFunctionLibrary(fdef_lib);
666 }
667 
TF_GraphNumFunctions(TF_Graph * g)668 int TF_GraphNumFunctions(TF_Graph* g) {
669   tensorflow::mutex_lock l(g->mu);
670   return g->graph.flib_def().num_functions();
671 }
672 
TF_GraphGetFunctions(TF_Graph * g,TF_Function ** funcs,int max_func,TF_Status * status)673 int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func,
674                          TF_Status* status) {
675   tensorflow::FunctionDefLibrary lib;
676   {
677     tensorflow::mutex_lock l(g->mu);
678     lib = g->graph.flib_def().ToProto();
679   }
680   const auto len = std::min(max_func, static_cast<int>(lib.function_size()));
681   for (int i = 0; i < len; ++i) {
682     TF_Function* func = new TF_Function();
683     func->fdef = lib.function(i);
684     funcs[i] = func;
685   }
686   status->status = tensorflow::Status::OK();
687   return len;
688 }
689 
TF_FunctionToFunctionDef(TF_Function * func,TF_Buffer * output_func_def,TF_Status * status)690 void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
691                               TF_Status* status) {
692   status->status = MessageToBuffer(func->fdef, output_func_def);
693 }
694 
TF_FunctionImportFunctionDef(const void * proto,size_t proto_len,TF_Status * status)695 TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len,
696                                           TF_Status* status) {
697   TF_Function* func = new TF_Function();
698   if (!func->fdef.ParseFromArray(proto, proto_len)) {
699     status->status = InvalidArgument(
700         "Invalid FunctionDef given to TF_FunctionImportFunctionDef");
701     TF_DeleteFunction(func);
702     return nullptr;
703   }
704   status->status = tensorflow::Status::OK();
705   return func;
706 }
707 
TF_FunctionSetAttrValueProto(TF_Function * func,const char * attr_name,const void * proto,size_t proto_len,TF_Status * status)708 void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name,
709                                   const void* proto, size_t proto_len,
710                                   TF_Status* status) {
711   tensorflow::AttrValue attr_value;
712   if (!attr_value.ParseFromArray(proto, proto_len)) {
713     status->status = InvalidArgument(
714         "Unparseable AttrValue proto passed to "
715         "TF_FunctionSetAttrValueProto");
716     return;
717   }
718   (*func->fdef.mutable_attr())[string(attr_name)] = attr_value;
719   status->status = tensorflow::Status::OK();
720 }
721 
TF_FunctionGetAttrValueProto(TF_Function * func,const char * attr_name,TF_Buffer * output_attr_value,TF_Status * status)722 void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name,
723                                   TF_Buffer* output_attr_value,
724                                   TF_Status* status) {
725   const auto& it = func->fdef.attr().find(attr_name);
726   if (it == func->fdef.attr().end()) {
727     status->status =
728         InvalidArgument("Function '", func->fdef.signature().name(),
729                         "' has no attr named '", attr_name, "'.");
730     return;
731   }
732   status->status = MessageToBuffer(it->second, output_attr_value);
733 }
734 
TF_DeleteFunction(TF_Function * func)735 void TF_DeleteFunction(TF_Function* func) { delete func; }
736