• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/tools/graph_transforms/transform_utils.h"
17 
18 #include "tensorflow/core/framework/node_def_util.h"
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/core/lib/hash/hash.h"
21 #include "tensorflow/core/lib/strings/str_util.h"
22 
23 namespace tensorflow {
24 namespace graph_transforms {
25 
26 namespace {
IsMerge(const NodeDef & node_def)27 inline bool IsMerge(const NodeDef& node_def) {
28   return node_def.op() == "Merge" || node_def.op() == "RefMerge";
29 }
30 
RecordMatchedNodes(const NodeMatch & match,std::set<string> * matched_nodes)31 void RecordMatchedNodes(const NodeMatch& match,
32                         std::set<string>* matched_nodes) {
33   matched_nodes->insert(match.node.name());
34   for (const NodeMatch& input_match : match.inputs) {
35     RecordMatchedNodes(input_match, matched_nodes);
36   }
37 }
38 
Hash64String(const string & input)39 inline uint64 Hash64String(const string& input) {
40   return Hash64(input.data(), input.size());
41 }
42 }  // namespace
43 
MatchedNodesAsArray(const NodeMatch & match,std::vector<NodeDef> * result)44 void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result) {
45   std::set<string> found_nodes;
46   std::vector<NodeMatch> current_matches = {match};
47   while (!current_matches.empty()) {
48     std::vector<NodeMatch> next_matches;
49     for (const NodeMatch& current_match : current_matches) {
50       if (found_nodes.count(current_match.node.name())) {
51         continue;
52       }
53       found_nodes.insert(current_match.node.name());
54       result->push_back(current_match.node);
55       for (const NodeMatch& input_match : current_match.inputs) {
56         next_matches.push_back(input_match);
57       }
58     }
59     current_matches = next_matches;
60   }
61 }
62 
MapNamesToNodes(const GraphDef & graph_def,std::map<string,const NodeDef * > * result)63 void MapNamesToNodes(const GraphDef& graph_def,
64                      std::map<string, const NodeDef*>* result) {
65   for (const NodeDef& node : graph_def.node()) {
66     (*result)[node.name()] = &node;
67   }
68 }
69 
MapNodesToOutputs(const GraphDef & graph_def,std::map<string,std::vector<const NodeDef * >> * result)70 void MapNodesToOutputs(const GraphDef& graph_def,
71                        std::map<string, std::vector<const NodeDef*>>* result) {
72   std::map<string, const NodeDef*> node_map;
73   MapNamesToNodes(graph_def, &node_map);
74   for (const NodeDef& node : graph_def.node()) {
75     for (const string& input : node.input()) {
76       string input_node_name = NodeNameFromInput(input);
77       (*result)[input_node_name].push_back(&node);
78     }
79   }
80 }
81 
NodeNamePartsFromInput(const string & input_name,string * prefix,string * node_name,string * suffix)82 void NodeNamePartsFromInput(const string& input_name, string* prefix,
83                             string* node_name, string* suffix) {
84   std::vector<string> input_parts = str_util::Split(input_name, ':');
85   if (input_parts.size() < 2) {
86     *suffix = "";
87   } else {
88     *suffix = ":" + input_parts[1];
89   }
90   StringPiece node_name_piece(input_parts[0]);
91   if (str_util::ConsumePrefix(&node_name_piece, "^")) {
92     *prefix = "^";
93   } else {
94     *prefix = "";
95   }
96   *node_name = string(node_name_piece);
97 }
98 
// Returns just the node-name part of an input reference, as extracted by
// NodeNamePartsFromInput (prefix and suffix are discarded).
string NodeNameFromInput(const string& input_name) {
  string unused_prefix;
  string bare_name;
  string unused_suffix;
  NodeNamePartsFromInput(input_name, &unused_prefix, &bare_name,
                         &unused_suffix);
  return bare_name;
}
106 
// Rebuilds an input reference as prefix + node name + suffix, substituting
// ":0" when the reference carries no suffix of its own.
string CanonicalInputName(const string& input_name) {
  string prefix;
  string node_name;
  string suffix;
  NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix);
  const string effective_suffix = suffix.empty() ? string(":0") : suffix;
  string canonical = prefix;
  canonical += node_name;
  canonical += effective_suffix;
  return canonical;
}
117 
HashNodeDef(const NodeDef & node)118 uint64 HashNodeDef(const NodeDef& node) {
119   uint64 hash = Hash64String(node.op());
120   hash = Hash64Combine(hash, Hash64String(node.name()));
121   for (const string& input : node.input()) {
122     hash = Hash64Combine(hash, Hash64String(CanonicalInputName(input)));
123   }
124   hash = Hash64Combine(hash, Hash64String(node.device()));
125   std::vector<string> attr_names;
126   attr_names.reserve(node.attr().size());
127   for (const auto& attr : node.attr()) {
128     attr_names.push_back(attr.first);
129   }
130   std::sort(attr_names.begin(), attr_names.end());
131   string attr_serialized;
132   for (const string& attr_name : attr_names) {
133     auto attr = node.attr().at(attr_name);
134     attr.SerializeToString(&attr_serialized);
135     hash = Hash64Combine(hash, Hash64String(attr_serialized));
136   }
137   return hash;
138 }
139 
AddNodeInput(const string & input_name,NodeDef * node)140 void AddNodeInput(const string& input_name, NodeDef* node) {
141   *(node->mutable_input()->Add()) = input_name;
142 }
143 
CopyNodeAttr(const NodeDef & source,const string & source_key,const string & dest_key,NodeDef * dest)144 void CopyNodeAttr(const NodeDef& source, const string& source_key,
145                   const string& dest_key, NodeDef* dest) {
146   CHECK_NE(0, source.attr().count(source_key))
147       << "No key '" << source_key << "' found in " << source.DebugString();
148   (*(dest->mutable_attr()))[dest_key] = source.attr().at(source_key);
149 }
150 
GetNodeTensorAttr(const NodeDef & node,const string & key)151 Tensor GetNodeTensorAttr(const NodeDef& node, const string& key) {
152   TensorProto tensor_proto = node.attr().at(key).tensor();
153   Tensor tensor;
154   CHECK(tensor.FromProto(tensor_proto));
155   return tensor;
156 }
157 
FilterGraphDef(const GraphDef & input_graph_def,std::function<bool (const NodeDef &)> selector,GraphDef * output_graph_def)158 void FilterGraphDef(const GraphDef& input_graph_def,
159                     std::function<bool(const NodeDef&)> selector,
160                     GraphDef* output_graph_def) {
161   output_graph_def->mutable_node()->Clear();
162   for (const NodeDef& node : input_graph_def.node()) {
163     if (selector(node)) {
164       *output_graph_def->mutable_node()->Add() = node;
165     }
166   }
167 }
168 
RemoveAttributes(const GraphDef & input_graph_def,const std::vector<string> & attributes,GraphDef * output_graph_def)169 void RemoveAttributes(const GraphDef& input_graph_def,
170                       const std::vector<string>& attributes,
171                       GraphDef* output_graph_def) {
172   output_graph_def->mutable_node()->Clear();
173   for (const NodeDef& node : input_graph_def.node()) {
174     NodeDef* new_node = output_graph_def->mutable_node()->Add();
175     *new_node = node;
176     for (const string& attribute : attributes) {
177       new_node->mutable_attr()->erase(attribute);
178     }
179   }
180 }
181 
SortByExecutionOrder(const GraphDef & input_graph_def,GraphDef * output_graph_def)182 Status SortByExecutionOrder(const GraphDef& input_graph_def,
183                             GraphDef* output_graph_def) {
184   const int num_nodes = input_graph_def.node_size();
185   std::vector<int> ready;
186   std::vector<int> pending_count;
187   pending_count.reserve(num_nodes);
188   std::vector<gtl::InlinedVector<int, 4>> outputs(num_nodes);
189 
190   std::map<string, int> name_index;
191   for (int i = 0; i < input_graph_def.node_size(); ++i) {
192     const NodeDef& node(input_graph_def.node(i));
193     name_index[node.name()] = i;
194   }
195 
196   // Parse the inputs for each node.
197   for (int n = 0; n < num_nodes; ++n) {
198     const NodeDef& node_def(input_graph_def.node(n));
199     if (IsMerge(node_def)) {
200       // for merge only wait for one non-control input.
201       int32 num_control_edges = 0;
202       for (int i = 0; i < node_def.input_size(); ++i) {
203         if (str_util::StartsWith(node_def.input(i), "^")) {
204           num_control_edges++;
205         }
206       }
207       pending_count.push_back(num_control_edges + 1);
208     } else {
209       pending_count.push_back(node_def.input_size());
210     }
211     if (node_def.input_size() == 0) {
212       ready.push_back(n);
213       continue;
214     }
215     for (int i = 0; i < node_def.input_size(); ++i) {
216       const string& input_name = node_def.input(i);
217       const string& input_node_name = NodeNameFromInput(input_name);
218       if (!name_index.count(input_node_name)) {
219         return errors::InvalidArgument("Node '", node_def.name(),
220                                        "': Unknown input node '",
221                                        node_def.input(i), "'");
222       }
223       outputs[name_index[input_node_name]].push_back(n);
224     }
225   }
226 
227   int processed = 0;
228   output_graph_def->Clear();
229   // Process the NodeDefs in topological order.
230   // Code above sets this up by filling in ready_ with nodes that have no
231   // inputs, pending_counts_ with the number of inputs for each node and
232   // outputs_ with the outputs of each node.
233   while (!ready.empty()) {
234     int o = ready.back();
235     ready.pop_back();
236     ++processed;
237     const NodeDef& node_def(input_graph_def.node(o));
238     *output_graph_def->mutable_node()->Add() = node_def;
239 
240     // Update pending_count for outputs.
241     for (size_t i = 0; i < outputs[o].size(); ++i) {
242       const int output = outputs[o][i];
243       pending_count[output]--;
244       if (pending_count[output] == 0) {
245         ready.push_back(output);
246       }
247     }
248   }
249 
250   if (processed < num_nodes) {
251     LOG(WARNING) << "IN " << __func__ << (num_nodes - processed)
252                  << " NODES IN A CYCLE";
253     for (int64 i = 0; i < num_nodes; i++) {
254       if (pending_count[i] != 0) {
255         LOG(WARNING) << "PENDING: " << SummarizeNodeDef(input_graph_def.node(i))
256                      << "WITH PENDING COUNT = " << pending_count[i];
257       }
258     }
259     return errors::InvalidArgument(num_nodes - processed, " nodes in a cycle");
260   }
261   return Status::OK();
262 }
263 
DebugString() const264 string OpTypePattern::DebugString() const {
265   string result = "{" + op + ", {";
266   for (const OpTypePattern& input : inputs) {
267     result += input.DebugString() + ",";
268   }
269   result += "}}";
270   return result;
271 }
272 
DebugString() const273 string NodeMatch::DebugString() const {
274   string result = "{";
275   result += node.DebugString();
276   result += ", {";
277   for (const NodeMatch& input : inputs) {
278     result += input.DebugString() + ",";
279   }
280   result += "}}";
281   return result;
282 }
283 
GraphMatcher(const GraphDef & graph_def)284 GraphMatcher::GraphMatcher(const GraphDef& graph_def) {
285   SortByExecutionOrder(graph_def, &graph_def_).IgnoreError();
286   MapNamesToNodes(graph_def_, &node_map_);
287 }
288 
GetOpTypeMatches(const OpTypePattern & pattern,std::vector<NodeMatch> * matches)289 Status GraphMatcher::GetOpTypeMatches(const OpTypePattern& pattern,
290                                       std::vector<NodeMatch>* matches) {
291   std::set<string> matched_nodes;
292   for (const NodeDef& node : graph_def_.node()) {
293     // Skip any nodes that are already part of a match.
294     if (matched_nodes.count(node.name())) {
295       continue;
296     }
297     NodeMatch match;
298     if (DoesOpTypeMatch(node, pattern, matched_nodes, &match)) {
299       RecordMatchedNodes(match, &matched_nodes);
300       matches->push_back(match);
301     }
302   }
303   return Status::OK();
304 }
305 
DoesOpTypeMatch(const NodeDef & node,const OpTypePattern & pattern,const std::set<string> & previously_matched_nodes,NodeMatch * match)306 bool GraphMatcher::DoesOpTypeMatch(
307     const NodeDef& node, const OpTypePattern& pattern,
308     const std::set<string>& previously_matched_nodes, NodeMatch* match) {
309   VLOG(1) << "Looking at node " << node.DebugString();
310   VLOG(1) << "pattern=" << pattern.DebugString();
311   VLOG(1) << "match=" << match->DebugString();
312   if (previously_matched_nodes.count(node.name())) {
313     VLOG(1) << "node " << node.name() << " has been previously matched";
314     return false;
315   }
316   bool pattern_matched = false;
317   if (pattern.op == "*") {
318     pattern_matched = true;
319   } else {
320     std::vector<string> pattern_ops = str_util::Split(pattern.op, '|');
321     for (const string& pattern_op : pattern_ops) {
322       if (node.op() == pattern_op) {
323         pattern_matched = true;
324       }
325     }
326   }
327   if (!pattern_matched) {
328     VLOG(1) << "node.op() != pattern.op()";
329     return false;
330   }
331   match->node = node;
332   // Ignore any control inputs for pattern-matching purposes
333   std::vector<string> non_control_inputs;
334   for (const string& input : node.input()) {
335     if (!input.empty() && (input[0] != '^')) {
336       non_control_inputs.push_back(input);
337     }
338   }
339   if (pattern.inputs.empty()) {
340     // If there are no inputs, assume that's the end of the pattern.
341     return true;
342   }
343   if (non_control_inputs.size() != pattern.inputs.size()) {
344     VLOG(1) << "non_control_inputs.size() != pattern.inputs.size()";
345     return false;
346   }
347   for (int i = 0; i < pattern.inputs.size(); ++i) {
348     const string& input_node_name = NodeNameFromInput(non_control_inputs[i]);
349     const NodeDef& input_node = *(node_map_[input_node_name]);
350     const OpTypePattern& input_pattern = pattern.inputs[i];
351     match->inputs.push_back(NodeMatch());
352     NodeMatch* input_match = &(match->inputs.back());
353     if (!DoesOpTypeMatch(input_node, input_pattern, previously_matched_nodes,
354                          input_match)) {
355       return false;
356     }
357   }
358   return true;
359 }
360 
ReplaceMatchingOpTypes(const GraphDef & input_graph_def,const OpTypePattern & pattern,const std::function<Status (const NodeMatch &,const std::set<string> &,const std::set<string> &,std::vector<NodeDef> *)> & node_generator,const ReplaceMatchingOpTypesOptions & options,GraphDef * output_graph_def)361 Status ReplaceMatchingOpTypes(
362     const GraphDef& input_graph_def, const OpTypePattern& pattern,
363     const std::function<Status(const NodeMatch&, const std::set<string>&,
364                                const std::set<string>&, std::vector<NodeDef>*)>&
365         node_generator,
366     const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def) {
367   // Start off by retrieving all the matching subgraphs.
368   GraphMatcher matcher(input_graph_def);
369   std::vector<NodeMatch> matches;
370   TF_RETURN_IF_ERROR(matcher.GetOpTypeMatches(pattern, &matches));
371 
372   // Do some housekeeping so we can easily look up the resulting matches given
373   // a node name.
374   std::set<string> matched_nodes;
375   std::map<string, const NodeMatch*> matches_by_head_name;
376   for (const NodeMatch& match : matches) {
377     matches_by_head_name[match.node.name()] = &match;
378     RecordMatchedNodes(match, &matched_nodes);
379   }
380   std::map<string, std::vector<const NodeDef*>> outputs_map;
381   MapNodesToOutputs(input_graph_def, &outputs_map);
382 
383   // Go through all the nodes in the input graph, see if they are part of a
384   // match or if they can be left untouched.
385   output_graph_def->Clear();
386   for (const NodeDef& input_node : input_graph_def.node()) {
387     if (matches_by_head_name.count(input_node.name())) {
388       // This node is the beginning of a match, so call the replacement function
389       // after setting up some information it will need.
390       const NodeMatch* match = matches_by_head_name[input_node.name()];
391       std::vector<NodeDef> matched_nodes_array;
392       MatchedNodesAsArray(*match, &matched_nodes_array);
393       // This tells us whether a node is part of the current match.
394       std::set<string> matched_nodes_lookup;
395       for (const NodeDef& matched_node : matched_nodes_array) {
396         matched_nodes_lookup.insert(matched_node.name());
397       }
398       // These are helper arrays that the replacement function can use to tell
399       // whether it can safely remove an internal node (because nothing outside
400       // of the match uses it) or whether external nodes depend on it.
401       std::set<string> input_nodes;
402       std::set<string> output_nodes;
403       for (const NodeDef& matched_node : matched_nodes_array) {
404         // Look through all of this node's inputs, and if any of them come from
405         // outside the match, then this should be noted as one of the external
406         // inputs of the subgraph.
407         for (const string& input_name : matched_node.input()) {
408           string input_node_name = NodeNameFromInput(input_name);
409           if (!matched_nodes_lookup.count(input_node_name)) {
410             input_nodes.insert(matched_node.name());
411           }
412         }
413         // Do a reverse input lookup, to see which other nodes use the current
414         // one as an input. If any of those nodes are outside the match
415         // subgraph, then the current node is marked as an output node that
416         // shouldn't be removed.
417         if (outputs_map.count(matched_node.name())) {
418           for (const NodeDef* dependent_node :
419                outputs_map[matched_node.name()]) {
420             if (!matched_nodes_lookup.count(dependent_node->name())) {
421               output_nodes.insert(matched_node.name());
422             }
423           }
424         }
425       }
426       // Call the generator function and add all the returned nodes to the
427       // graph.
428       std::vector<NodeDef> new_nodes;
429       TF_RETURN_IF_ERROR(
430           node_generator(*match, input_nodes, output_nodes, &new_nodes));
431       std::set<string> new_node_names;
432       for (const NodeDef& new_node : new_nodes) {
433         new_node_names.insert(new_node.name());
434       }
435       // Check to make sure the generator function preserved all of the nodes
436       // that are used elsewhere in the graph, and add them back in if not.
437       bool abort_replacement = false;
438       if (!options.allow_inconsistencies) {
439         for (const string& expected_output : output_nodes) {
440           if (!new_node_names.count(expected_output)) {
441             LOG(WARNING) << "Expected " << expected_output
442                          << " to be preserved.";
443             abort_replacement = true;
444           }
445         }
446       }
447       if (abort_replacement) {
448         LOG(WARNING) << "Generator function didn't preserve needed nodes, "
449                      << "copying old replacements back in instead.";
450         std::vector<NodeDef> old_nodes;
451         MatchedNodesAsArray(*match, &old_nodes);
452         for (const NodeDef& old_node : old_nodes) {
453           NodeDef* added_node = output_graph_def->mutable_node()->Add();
454           *added_node = old_node;
455         }
456       } else {
457         for (const NodeDef& new_node : new_nodes) {
458           NodeDef* added_node = output_graph_def->mutable_node()->Add();
459           *added_node = new_node;
460         }
461       }
462     } else if (!matched_nodes.count(input_node.name())) {
463       // This node isn't part of any match, so just copy it over.
464       NodeDef* added_node = output_graph_def->mutable_node()->Add();
465       *added_node = input_node;
466     } else {
467       // Do nothing, because this is an internal part of a matching subgraph,
468       // and so will have been replaced by a new replacement subgraph.
469     }
470   }
471 
472   return Status::OK();
473 }
474 
RenameNodeInputs(const GraphDef & input_graph_def,const std::map<string,string> & inputs_to_rename,const std::unordered_set<string> & nodes_to_ignore,GraphDef * output_graph_def)475 Status RenameNodeInputs(const GraphDef& input_graph_def,
476                         const std::map<string, string>& inputs_to_rename,
477                         const std::unordered_set<string>& nodes_to_ignore,
478                         GraphDef* output_graph_def) {
479   std::map<string, std::vector<std::pair<string, string>>>
480       canonical_inputs_to_rename;
481   for (const auto& input_to_rename : inputs_to_rename) {
482     canonical_inputs_to_rename[NodeNameFromInput(input_to_rename.first)]
483         .push_back({input_to_rename.first, input_to_rename.second});
484   }
485 
486   output_graph_def->Clear();
487   for (const NodeDef& node : input_graph_def.node()) {
488     NodeDef* new_node = output_graph_def->mutable_node()->Add();
489     *new_node = node;
490     new_node->mutable_input()->Clear();
491     for (const string& input_name : node.input()) {
492       std::set<string> already_visited;
493       string new_input_name = input_name;
494       while (
495           canonical_inputs_to_rename.count(NodeNameFromInput(new_input_name))) {
496         string input_node_name = NodeNameFromInput(new_input_name);
497         if (already_visited.count(input_node_name)) {
498           return errors::InvalidArgument(
499               "RenameNodeInputs argument contains a cycle for ",
500               input_node_name);
501         }
502         already_visited.insert(input_node_name);
503         if (nodes_to_ignore.count(node.name())) {
504           break;
505         }
506         bool any_match_found = false;
507         for (const std::pair<string, string>& input_to_rename :
508              canonical_inputs_to_rename.at(input_node_name)) {
509           const string& source_name = input_to_rename.first;
510           const string& dest_name = input_to_rename.second;
511           bool is_match;
512           string match_name;
513           if (str_util::EndsWith(source_name, ":*")) {
514             is_match = true;
515             string prefix;
516             string unused_node_name;
517             string suffix;
518             NodeNamePartsFromInput(new_input_name, &prefix, &unused_node_name,
519                                    &suffix);
520             match_name = prefix + dest_name + suffix;
521           } else {
522             is_match = (CanonicalInputName(source_name) ==
523                         CanonicalInputName(new_input_name));
524             match_name = dest_name;
525           }
526           if (is_match) {
527             new_input_name = match_name;
528             any_match_found = true;
529           }
530         }
531         if (!any_match_found) {
532           break;
533         }
534       }
535       *(new_node->mutable_input()->Add()) = new_input_name;
536     }
537   }
538   return Status::OK();
539 }
540 
CopyOriginalMatch(const NodeMatch & match,std::vector<NodeDef> * new_nodes)541 void CopyOriginalMatch(const NodeMatch& match,
542                        std::vector<NodeDef>* new_nodes) {
543   std::vector<NodeDef> old_nodes;
544   MatchedNodesAsArray(match, &old_nodes);
545   for (const NodeDef& old_node : old_nodes) {
546     new_nodes->push_back(old_node);
547   }
548 }
549 
GetTransformRegistry()550 TransformRegistry* GetTransformRegistry() {
551   static TransformRegistry transform_registry;
552   return &transform_registry;
553 }
554 
FindInvalidInputs(const GraphDef & graph_def,std::vector<std::pair<string,string>> * invalid_inputs)555 void FindInvalidInputs(const GraphDef& graph_def,
556                        std::vector<std::pair<string, string>>* invalid_inputs) {
557   std::map<string, const NodeDef*> node_map;
558   MapNamesToNodes(graph_def, &node_map);
559 
560   for (const NodeDef& node : graph_def.node()) {
561     for (const string& input : node.input()) {
562       string input_node = NodeNameFromInput(input);
563       if (!node_map.count(input_node)) {
564         invalid_inputs->push_back({node.name(), input_node});
565       }
566     }
567   }
568 }
569 
IsGraphValid(const GraphDef & graph_def)570 Status IsGraphValid(const GraphDef& graph_def) {
571   std::vector<std::pair<string, string>> invalid_inputs;
572   FindInvalidInputs(graph_def, &invalid_inputs);
573   if (!invalid_inputs.empty()) {
574     std::map<string, const NodeDef*> node_map;
575     MapNamesToNodes(graph_def, &node_map);
576     for (const std::pair<string, string>& invalid_input : invalid_inputs) {
577       LOG(ERROR) << "Invalid input " << invalid_input.second << " for node "
578                  << invalid_input.first << " - "
579                  << node_map[invalid_input.first]->DebugString();
580     }
581     return errors::Internal(
582         "Invalid graph with inputs referring to nonexistent nodes");
583   }
584   return Status::OK();
585 }
586 
GetInOutTypes(const NodeDef & node_def,DataTypeVector * inputs,DataTypeVector * outputs)587 Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
588                      DataTypeVector* outputs) {
589   const OpDef* op_def;
590   TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));
591   TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, *op_def, inputs, outputs));
592   return Status::OK();
593 }
594 
TensorShapeFromString(const string & shape_string,TensorShape * result)595 Status TensorShapeFromString(const string& shape_string, TensorShape* result) {
596   if (shape_string.empty()) {
597     return errors::InvalidArgument("Specificed shape is empty.");
598   }
599   std::vector<int64> dims;
600   if (!str_util::SplitAndParseAsInts(shape_string, ',', &dims)) {
601     return errors::InvalidArgument("Could parse as shape: '", shape_string,
602                                    "'");
603   }
604   *result = TensorShape(dims);
605   return Status::OK();
606 }
607 
CountParameters(const string & name) const608 int TransformFuncContext::CountParameters(const string& name) const {
609   if (params.count(name)) {
610     return params.at(name).size();
611   } else {
612     return 0;
613   }
614 }
615 
GetOneStringParameter(const string & name,const string & default_value,string * result) const616 Status TransformFuncContext::GetOneStringParameter(const string& name,
617                                                    const string& default_value,
618                                                    string* result) const {
619   const int params_count = CountParameters(name);
620   if (params_count == 0) {
621     *result = default_value;
622     return Status::OK();
623   } else if (params_count == 1) {
624     *result = params.at(name).at(0);
625     return Status::OK();
626   } else {
627     return errors::InvalidArgument("Expected a single '", name,
628                                    "' parameter, but found ", params_count,
629                                    " occurrences");
630   }
631 }
632 
GetOneInt32Parameter(const string & name,int32 default_value,int32 * result) const633 Status TransformFuncContext::GetOneInt32Parameter(const string& name,
634                                                   int32 default_value,
635                                                   int32* result) const {
636   const int params_count = CountParameters(name);
637   if (params_count == 0) {
638     *result = default_value;
639     return Status::OK();
640   }
641   string string_value;
642   TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
643   if (!strings::safe_strto32(StringPiece(string_value), result)) {
644     return errors::InvalidArgument("Couldn't interpret the ", name,
645                                    " argument as a number:", string_value);
646   }
647   return Status::OK();
648 }
649 
GetOneInt64Parameter(const string & name,int64 default_value,int64 * result) const650 Status TransformFuncContext::GetOneInt64Parameter(const string& name,
651                                                   int64 default_value,
652                                                   int64* result) const {
653   const int params_count = CountParameters(name);
654   if (params_count == 0) {
655     *result = default_value;
656     return Status::OK();
657   }
658   string string_value;
659   TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
660   if (!strings::safe_strto64(StringPiece(string_value), result)) {
661     return errors::InvalidArgument("Couldn't interpret the ", name,
662                                    " argument as a number:", string_value);
663   }
664   return Status::OK();
665 }
666 
GetOneFloatParameter(const string & name,float default_value,float * result) const667 Status TransformFuncContext::GetOneFloatParameter(const string& name,
668                                                   float default_value,
669                                                   float* result) const {
670   const int params_count = CountParameters(name);
671   if (params_count == 0) {
672     *result = default_value;
673     return Status::OK();
674   }
675   string string_value;
676   TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
677   if (!strings::safe_strtof(string_value.c_str(), result)) {
678     return errors::InvalidArgument(
679         "Couldn't interpret the ", name,
680         " argument as a float number:", string_value);
681   }
682   return Status::OK();
683 }
684 
GetOneBoolParameter(const string & name,bool default_value,bool * result) const685 Status TransformFuncContext::GetOneBoolParameter(const string& name,
686                                                  bool default_value,
687                                                  bool* result) const {
688   const int params_count = CountParameters(name);
689   if (params_count == 0) {
690     *result = default_value;
691     return Status::OK();
692   }
693   string string_value;
694   TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
695   if (string_value == "true" || string_value == "1") {
696     *result = true;
697   } else if (string_value == "false" || string_value == "0") {
698     *result = false;
699   } else {
700     return errors::InvalidArgument("Couldn't interpret the ", name,
701                                    " argument as a boolean:", string_value,
702                                    " (expected true, false, 0 or 1)");
703   }
704   return Status::OK();
705 }
706 
707 }  // namespace graph_transforms
708 }  // namespace tensorflow
709