1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/tools/graph_transforms/transform_utils.h"
17
18 #include "tensorflow/core/framework/node_def_util.h"
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/core/lib/hash/hash.h"
21 #include "tensorflow/core/lib/strings/str_util.h"
22
23 namespace tensorflow {
24 namespace graph_transforms {
25
26 namespace {
IsMerge(const NodeDef & node_def)27 inline bool IsMerge(const NodeDef& node_def) {
28 return node_def.op() == "Merge" || node_def.op() == "RefMerge";
29 }
30
RecordMatchedNodes(const NodeMatch & match,std::set<string> * matched_nodes)31 void RecordMatchedNodes(const NodeMatch& match,
32 std::set<string>* matched_nodes) {
33 matched_nodes->insert(match.node.name());
34 for (const NodeMatch& input_match : match.inputs) {
35 RecordMatchedNodes(input_match, matched_nodes);
36 }
37 }
38
Hash64String(const string & input)39 inline uint64 Hash64String(const string& input) {
40 return Hash64(input.data(), input.size());
41 }
42 } // namespace
43
MatchedNodesAsArray(const NodeMatch & match,std::vector<NodeDef> * result)44 void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result) {
45 std::set<string> found_nodes;
46 std::vector<NodeMatch> current_matches = {match};
47 while (!current_matches.empty()) {
48 std::vector<NodeMatch> next_matches;
49 for (const NodeMatch& current_match : current_matches) {
50 if (found_nodes.count(current_match.node.name())) {
51 continue;
52 }
53 found_nodes.insert(current_match.node.name());
54 result->push_back(current_match.node);
55 for (const NodeMatch& input_match : current_match.inputs) {
56 next_matches.push_back(input_match);
57 }
58 }
59 current_matches = next_matches;
60 }
61 }
62
MapNamesToNodes(const GraphDef & graph_def,std::map<string,const NodeDef * > * result)63 void MapNamesToNodes(const GraphDef& graph_def,
64 std::map<string, const NodeDef*>* result) {
65 for (const NodeDef& node : graph_def.node()) {
66 (*result)[node.name()] = &node;
67 }
68 }
69
MapNodesToOutputs(const GraphDef & graph_def,std::map<string,std::vector<const NodeDef * >> * result)70 void MapNodesToOutputs(const GraphDef& graph_def,
71 std::map<string, std::vector<const NodeDef*>>* result) {
72 std::map<string, const NodeDef*> node_map;
73 MapNamesToNodes(graph_def, &node_map);
74 for (const NodeDef& node : graph_def.node()) {
75 for (const string& input : node.input()) {
76 string input_node_name = NodeNameFromInput(input);
77 (*result)[input_node_name].push_back(&node);
78 }
79 }
80 }
81
NodeNamePartsFromInput(const string & input_name,string * prefix,string * node_name,string * suffix)82 void NodeNamePartsFromInput(const string& input_name, string* prefix,
83 string* node_name, string* suffix) {
84 std::vector<string> input_parts = str_util::Split(input_name, ':');
85 if (input_parts.size() < 2) {
86 *suffix = "";
87 } else {
88 *suffix = ":" + input_parts[1];
89 }
90 StringPiece node_name_piece(input_parts[0]);
91 if (str_util::ConsumePrefix(&node_name_piece, "^")) {
92 *prefix = "^";
93 } else {
94 *prefix = "";
95 }
96 *node_name = string(node_name_piece);
97 }
98
// Returns just the node-name part of an input reference, dropping any "^"
// control prefix and any ":N" output suffix.
string NodeNameFromInput(const string& input_name) {
  string unused_prefix;
  string unused_suffix;
  string name_only;
  NodeNamePartsFromInput(input_name, &unused_prefix, &name_only,
                         &unused_suffix);
  return name_only;
}
106
// Returns `input_name` with an explicit output index, so that "node" and
// "node:0" compare equal. Any "^" prefix and non-zero suffix are preserved.
string CanonicalInputName(const string& input_name) {
  string control_prefix;
  string bare_name;
  string output_suffix;
  NodeNamePartsFromInput(input_name, &control_prefix, &bare_name,
                         &output_suffix);
  const string effective_suffix =
      output_suffix.empty() ? string(":0") : output_suffix;
  return control_prefix + bare_name + effective_suffix;
}
117
HashNodeDef(const NodeDef & node)118 uint64 HashNodeDef(const NodeDef& node) {
119 uint64 hash = Hash64String(node.op());
120 hash = Hash64Combine(hash, Hash64String(node.name()));
121 for (const string& input : node.input()) {
122 hash = Hash64Combine(hash, Hash64String(CanonicalInputName(input)));
123 }
124 hash = Hash64Combine(hash, Hash64String(node.device()));
125 std::vector<string> attr_names;
126 attr_names.reserve(node.attr().size());
127 for (const auto& attr : node.attr()) {
128 attr_names.push_back(attr.first);
129 }
130 std::sort(attr_names.begin(), attr_names.end());
131 string attr_serialized;
132 for (const string& attr_name : attr_names) {
133 auto attr = node.attr().at(attr_name);
134 attr.SerializeToString(&attr_serialized);
135 hash = Hash64Combine(hash, Hash64String(attr_serialized));
136 }
137 return hash;
138 }
139
AddNodeInput(const string & input_name,NodeDef * node)140 void AddNodeInput(const string& input_name, NodeDef* node) {
141 *(node->mutable_input()->Add()) = input_name;
142 }
143
CopyNodeAttr(const NodeDef & source,const string & source_key,const string & dest_key,NodeDef * dest)144 void CopyNodeAttr(const NodeDef& source, const string& source_key,
145 const string& dest_key, NodeDef* dest) {
146 CHECK_NE(0, source.attr().count(source_key))
147 << "No key '" << source_key << "' found in " << source.DebugString();
148 (*(dest->mutable_attr()))[dest_key] = source.attr().at(source_key);
149 }
150
GetNodeTensorAttr(const NodeDef & node,const string & key)151 Tensor GetNodeTensorAttr(const NodeDef& node, const string& key) {
152 TensorProto tensor_proto = node.attr().at(key).tensor();
153 Tensor tensor;
154 CHECK(tensor.FromProto(tensor_proto));
155 return tensor;
156 }
157
FilterGraphDef(const GraphDef & input_graph_def,std::function<bool (const NodeDef &)> selector,GraphDef * output_graph_def)158 void FilterGraphDef(const GraphDef& input_graph_def,
159 std::function<bool(const NodeDef&)> selector,
160 GraphDef* output_graph_def) {
161 output_graph_def->mutable_node()->Clear();
162 for (const NodeDef& node : input_graph_def.node()) {
163 if (selector(node)) {
164 *output_graph_def->mutable_node()->Add() = node;
165 }
166 }
167 }
168
RemoveAttributes(const GraphDef & input_graph_def,const std::vector<string> & attributes,GraphDef * output_graph_def)169 void RemoveAttributes(const GraphDef& input_graph_def,
170 const std::vector<string>& attributes,
171 GraphDef* output_graph_def) {
172 output_graph_def->mutable_node()->Clear();
173 for (const NodeDef& node : input_graph_def.node()) {
174 NodeDef* new_node = output_graph_def->mutable_node()->Add();
175 *new_node = node;
176 for (const string& attribute : attributes) {
177 new_node->mutable_attr()->erase(attribute);
178 }
179 }
180 }
181
SortByExecutionOrder(const GraphDef & input_graph_def,GraphDef * output_graph_def)182 Status SortByExecutionOrder(const GraphDef& input_graph_def,
183 GraphDef* output_graph_def) {
184 const int num_nodes = input_graph_def.node_size();
185 std::vector<int> ready;
186 std::vector<int> pending_count;
187 pending_count.reserve(num_nodes);
188 std::vector<gtl::InlinedVector<int, 4>> outputs(num_nodes);
189
190 std::map<string, int> name_index;
191 for (int i = 0; i < input_graph_def.node_size(); ++i) {
192 const NodeDef& node(input_graph_def.node(i));
193 name_index[node.name()] = i;
194 }
195
196 // Parse the inputs for each node.
197 for (int n = 0; n < num_nodes; ++n) {
198 const NodeDef& node_def(input_graph_def.node(n));
199 if (IsMerge(node_def)) {
200 // for merge only wait for one non-control input.
201 int32 num_control_edges = 0;
202 for (int i = 0; i < node_def.input_size(); ++i) {
203 if (str_util::StartsWith(node_def.input(i), "^")) {
204 num_control_edges++;
205 }
206 }
207 pending_count.push_back(num_control_edges + 1);
208 } else {
209 pending_count.push_back(node_def.input_size());
210 }
211 if (node_def.input_size() == 0) {
212 ready.push_back(n);
213 continue;
214 }
215 for (int i = 0; i < node_def.input_size(); ++i) {
216 const string& input_name = node_def.input(i);
217 const string& input_node_name = NodeNameFromInput(input_name);
218 if (!name_index.count(input_node_name)) {
219 return errors::InvalidArgument("Node '", node_def.name(),
220 "': Unknown input node '",
221 node_def.input(i), "'");
222 }
223 outputs[name_index[input_node_name]].push_back(n);
224 }
225 }
226
227 int processed = 0;
228 output_graph_def->Clear();
229 // Process the NodeDefs in topological order.
230 // Code above sets this up by filling in ready_ with nodes that have no
231 // inputs, pending_counts_ with the number of inputs for each node and
232 // outputs_ with the outputs of each node.
233 while (!ready.empty()) {
234 int o = ready.back();
235 ready.pop_back();
236 ++processed;
237 const NodeDef& node_def(input_graph_def.node(o));
238 *output_graph_def->mutable_node()->Add() = node_def;
239
240 // Update pending_count for outputs.
241 for (size_t i = 0; i < outputs[o].size(); ++i) {
242 const int output = outputs[o][i];
243 pending_count[output]--;
244 if (pending_count[output] == 0) {
245 ready.push_back(output);
246 }
247 }
248 }
249
250 if (processed < num_nodes) {
251 LOG(WARNING) << "IN " << __func__ << (num_nodes - processed)
252 << " NODES IN A CYCLE";
253 for (int64 i = 0; i < num_nodes; i++) {
254 if (pending_count[i] != 0) {
255 LOG(WARNING) << "PENDING: " << SummarizeNodeDef(input_graph_def.node(i))
256 << "WITH PENDING COUNT = " << pending_count[i];
257 }
258 }
259 return errors::InvalidArgument(num_nodes - processed, " nodes in a cycle");
260 }
261 return Status::OK();
262 }
263
DebugString() const264 string OpTypePattern::DebugString() const {
265 string result = "{" + op + ", {";
266 for (const OpTypePattern& input : inputs) {
267 result += input.DebugString() + ",";
268 }
269 result += "}}";
270 return result;
271 }
272
DebugString() const273 string NodeMatch::DebugString() const {
274 string result = "{";
275 result += node.DebugString();
276 result += ", {";
277 for (const NodeMatch& input : inputs) {
278 result += input.DebugString() + ",";
279 }
280 result += "}}";
281 return result;
282 }
283
GraphMatcher(const GraphDef & graph_def)284 GraphMatcher::GraphMatcher(const GraphDef& graph_def) {
285 SortByExecutionOrder(graph_def, &graph_def_).IgnoreError();
286 MapNamesToNodes(graph_def_, &node_map_);
287 }
288
GetOpTypeMatches(const OpTypePattern & pattern,std::vector<NodeMatch> * matches)289 Status GraphMatcher::GetOpTypeMatches(const OpTypePattern& pattern,
290 std::vector<NodeMatch>* matches) {
291 std::set<string> matched_nodes;
292 for (const NodeDef& node : graph_def_.node()) {
293 // Skip any nodes that are already part of a match.
294 if (matched_nodes.count(node.name())) {
295 continue;
296 }
297 NodeMatch match;
298 if (DoesOpTypeMatch(node, pattern, matched_nodes, &match)) {
299 RecordMatchedNodes(match, &matched_nodes);
300 matches->push_back(match);
301 }
302 }
303 return Status::OK();
304 }
305
DoesOpTypeMatch(const NodeDef & node,const OpTypePattern & pattern,const std::set<string> & previously_matched_nodes,NodeMatch * match)306 bool GraphMatcher::DoesOpTypeMatch(
307 const NodeDef& node, const OpTypePattern& pattern,
308 const std::set<string>& previously_matched_nodes, NodeMatch* match) {
309 VLOG(1) << "Looking at node " << node.DebugString();
310 VLOG(1) << "pattern=" << pattern.DebugString();
311 VLOG(1) << "match=" << match->DebugString();
312 if (previously_matched_nodes.count(node.name())) {
313 VLOG(1) << "node " << node.name() << " has been previously matched";
314 return false;
315 }
316 bool pattern_matched = false;
317 if (pattern.op == "*") {
318 pattern_matched = true;
319 } else {
320 std::vector<string> pattern_ops = str_util::Split(pattern.op, '|');
321 for (const string& pattern_op : pattern_ops) {
322 if (node.op() == pattern_op) {
323 pattern_matched = true;
324 }
325 }
326 }
327 if (!pattern_matched) {
328 VLOG(1) << "node.op() != pattern.op()";
329 return false;
330 }
331 match->node = node;
332 // Ignore any control inputs for pattern-matching purposes
333 std::vector<string> non_control_inputs;
334 for (const string& input : node.input()) {
335 if (!input.empty() && (input[0] != '^')) {
336 non_control_inputs.push_back(input);
337 }
338 }
339 if (pattern.inputs.empty()) {
340 // If there are no inputs, assume that's the end of the pattern.
341 return true;
342 }
343 if (non_control_inputs.size() != pattern.inputs.size()) {
344 VLOG(1) << "non_control_inputs.size() != pattern.inputs.size()";
345 return false;
346 }
347 for (int i = 0; i < pattern.inputs.size(); ++i) {
348 const string& input_node_name = NodeNameFromInput(non_control_inputs[i]);
349 const NodeDef& input_node = *(node_map_[input_node_name]);
350 const OpTypePattern& input_pattern = pattern.inputs[i];
351 match->inputs.push_back(NodeMatch());
352 NodeMatch* input_match = &(match->inputs.back());
353 if (!DoesOpTypeMatch(input_node, input_pattern, previously_matched_nodes,
354 input_match)) {
355 return false;
356 }
357 }
358 return true;
359 }
360
ReplaceMatchingOpTypes(const GraphDef & input_graph_def,const OpTypePattern & pattern,const std::function<Status (const NodeMatch &,const std::set<string> &,const std::set<string> &,std::vector<NodeDef> *)> & node_generator,const ReplaceMatchingOpTypesOptions & options,GraphDef * output_graph_def)361 Status ReplaceMatchingOpTypes(
362 const GraphDef& input_graph_def, const OpTypePattern& pattern,
363 const std::function<Status(const NodeMatch&, const std::set<string>&,
364 const std::set<string>&, std::vector<NodeDef>*)>&
365 node_generator,
366 const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def) {
367 // Start off by retrieving all the matching subgraphs.
368 GraphMatcher matcher(input_graph_def);
369 std::vector<NodeMatch> matches;
370 TF_RETURN_IF_ERROR(matcher.GetOpTypeMatches(pattern, &matches));
371
372 // Do some housekeeping so we can easily look up the resulting matches given
373 // a node name.
374 std::set<string> matched_nodes;
375 std::map<string, const NodeMatch*> matches_by_head_name;
376 for (const NodeMatch& match : matches) {
377 matches_by_head_name[match.node.name()] = &match;
378 RecordMatchedNodes(match, &matched_nodes);
379 }
380 std::map<string, std::vector<const NodeDef*>> outputs_map;
381 MapNodesToOutputs(input_graph_def, &outputs_map);
382
383 // Go through all the nodes in the input graph, see if they are part of a
384 // match or if they can be left untouched.
385 output_graph_def->Clear();
386 for (const NodeDef& input_node : input_graph_def.node()) {
387 if (matches_by_head_name.count(input_node.name())) {
388 // This node is the beginning of a match, so call the replacement function
389 // after setting up some information it will need.
390 const NodeMatch* match = matches_by_head_name[input_node.name()];
391 std::vector<NodeDef> matched_nodes_array;
392 MatchedNodesAsArray(*match, &matched_nodes_array);
393 // This tells us whether a node is part of the current match.
394 std::set<string> matched_nodes_lookup;
395 for (const NodeDef& matched_node : matched_nodes_array) {
396 matched_nodes_lookup.insert(matched_node.name());
397 }
398 // These are helper arrays that the replacement function can use to tell
399 // whether it can safely remove an internal node (because nothing outside
400 // of the match uses it) or whether external nodes depend on it.
401 std::set<string> input_nodes;
402 std::set<string> output_nodes;
403 for (const NodeDef& matched_node : matched_nodes_array) {
404 // Look through all of this node's inputs, and if any of them come from
405 // outside the match, then this should be noted as one of the external
406 // inputs of the subgraph.
407 for (const string& input_name : matched_node.input()) {
408 string input_node_name = NodeNameFromInput(input_name);
409 if (!matched_nodes_lookup.count(input_node_name)) {
410 input_nodes.insert(matched_node.name());
411 }
412 }
413 // Do a reverse input lookup, to see which other nodes use the current
414 // one as an input. If any of those nodes are outside the match
415 // subgraph, then the current node is marked as an output node that
416 // shouldn't be removed.
417 if (outputs_map.count(matched_node.name())) {
418 for (const NodeDef* dependent_node :
419 outputs_map[matched_node.name()]) {
420 if (!matched_nodes_lookup.count(dependent_node->name())) {
421 output_nodes.insert(matched_node.name());
422 }
423 }
424 }
425 }
426 // Call the generator function and add all the returned nodes to the
427 // graph.
428 std::vector<NodeDef> new_nodes;
429 TF_RETURN_IF_ERROR(
430 node_generator(*match, input_nodes, output_nodes, &new_nodes));
431 std::set<string> new_node_names;
432 for (const NodeDef& new_node : new_nodes) {
433 new_node_names.insert(new_node.name());
434 }
435 // Check to make sure the generator function preserved all of the nodes
436 // that are used elsewhere in the graph, and add them back in if not.
437 bool abort_replacement = false;
438 if (!options.allow_inconsistencies) {
439 for (const string& expected_output : output_nodes) {
440 if (!new_node_names.count(expected_output)) {
441 LOG(WARNING) << "Expected " << expected_output
442 << " to be preserved.";
443 abort_replacement = true;
444 }
445 }
446 }
447 if (abort_replacement) {
448 LOG(WARNING) << "Generator function didn't preserve needed nodes, "
449 << "copying old replacements back in instead.";
450 std::vector<NodeDef> old_nodes;
451 MatchedNodesAsArray(*match, &old_nodes);
452 for (const NodeDef& old_node : old_nodes) {
453 NodeDef* added_node = output_graph_def->mutable_node()->Add();
454 *added_node = old_node;
455 }
456 } else {
457 for (const NodeDef& new_node : new_nodes) {
458 NodeDef* added_node = output_graph_def->mutable_node()->Add();
459 *added_node = new_node;
460 }
461 }
462 } else if (!matched_nodes.count(input_node.name())) {
463 // This node isn't part of any match, so just copy it over.
464 NodeDef* added_node = output_graph_def->mutable_node()->Add();
465 *added_node = input_node;
466 } else {
467 // Do nothing, because this is an internal part of a matching subgraph,
468 // and so will have been replaced by a new replacement subgraph.
469 }
470 }
471
472 return Status::OK();
473 }
474
RenameNodeInputs(const GraphDef & input_graph_def,const std::map<string,string> & inputs_to_rename,const std::unordered_set<string> & nodes_to_ignore,GraphDef * output_graph_def)475 Status RenameNodeInputs(const GraphDef& input_graph_def,
476 const std::map<string, string>& inputs_to_rename,
477 const std::unordered_set<string>& nodes_to_ignore,
478 GraphDef* output_graph_def) {
479 std::map<string, std::vector<std::pair<string, string>>>
480 canonical_inputs_to_rename;
481 for (const auto& input_to_rename : inputs_to_rename) {
482 canonical_inputs_to_rename[NodeNameFromInput(input_to_rename.first)]
483 .push_back({input_to_rename.first, input_to_rename.second});
484 }
485
486 output_graph_def->Clear();
487 for (const NodeDef& node : input_graph_def.node()) {
488 NodeDef* new_node = output_graph_def->mutable_node()->Add();
489 *new_node = node;
490 new_node->mutable_input()->Clear();
491 for (const string& input_name : node.input()) {
492 std::set<string> already_visited;
493 string new_input_name = input_name;
494 while (
495 canonical_inputs_to_rename.count(NodeNameFromInput(new_input_name))) {
496 string input_node_name = NodeNameFromInput(new_input_name);
497 if (already_visited.count(input_node_name)) {
498 return errors::InvalidArgument(
499 "RenameNodeInputs argument contains a cycle for ",
500 input_node_name);
501 }
502 already_visited.insert(input_node_name);
503 if (nodes_to_ignore.count(node.name())) {
504 break;
505 }
506 bool any_match_found = false;
507 for (const std::pair<string, string>& input_to_rename :
508 canonical_inputs_to_rename.at(input_node_name)) {
509 const string& source_name = input_to_rename.first;
510 const string& dest_name = input_to_rename.second;
511 bool is_match;
512 string match_name;
513 if (str_util::EndsWith(source_name, ":*")) {
514 is_match = true;
515 string prefix;
516 string unused_node_name;
517 string suffix;
518 NodeNamePartsFromInput(new_input_name, &prefix, &unused_node_name,
519 &suffix);
520 match_name = prefix + dest_name + suffix;
521 } else {
522 is_match = (CanonicalInputName(source_name) ==
523 CanonicalInputName(new_input_name));
524 match_name = dest_name;
525 }
526 if (is_match) {
527 new_input_name = match_name;
528 any_match_found = true;
529 }
530 }
531 if (!any_match_found) {
532 break;
533 }
534 }
535 *(new_node->mutable_input()->Add()) = new_input_name;
536 }
537 }
538 return Status::OK();
539 }
540
CopyOriginalMatch(const NodeMatch & match,std::vector<NodeDef> * new_nodes)541 void CopyOriginalMatch(const NodeMatch& match,
542 std::vector<NodeDef>* new_nodes) {
543 std::vector<NodeDef> old_nodes;
544 MatchedNodesAsArray(match, &old_nodes);
545 for (const NodeDef& old_node : old_nodes) {
546 new_nodes->push_back(old_node);
547 }
548 }
549
GetTransformRegistry()550 TransformRegistry* GetTransformRegistry() {
551 static TransformRegistry transform_registry;
552 return &transform_registry;
553 }
554
FindInvalidInputs(const GraphDef & graph_def,std::vector<std::pair<string,string>> * invalid_inputs)555 void FindInvalidInputs(const GraphDef& graph_def,
556 std::vector<std::pair<string, string>>* invalid_inputs) {
557 std::map<string, const NodeDef*> node_map;
558 MapNamesToNodes(graph_def, &node_map);
559
560 for (const NodeDef& node : graph_def.node()) {
561 for (const string& input : node.input()) {
562 string input_node = NodeNameFromInput(input);
563 if (!node_map.count(input_node)) {
564 invalid_inputs->push_back({node.name(), input_node});
565 }
566 }
567 }
568 }
569
IsGraphValid(const GraphDef & graph_def)570 Status IsGraphValid(const GraphDef& graph_def) {
571 std::vector<std::pair<string, string>> invalid_inputs;
572 FindInvalidInputs(graph_def, &invalid_inputs);
573 if (!invalid_inputs.empty()) {
574 std::map<string, const NodeDef*> node_map;
575 MapNamesToNodes(graph_def, &node_map);
576 for (const std::pair<string, string>& invalid_input : invalid_inputs) {
577 LOG(ERROR) << "Invalid input " << invalid_input.second << " for node "
578 << invalid_input.first << " - "
579 << node_map[invalid_input.first]->DebugString();
580 }
581 return errors::Internal(
582 "Invalid graph with inputs referring to nonexistent nodes");
583 }
584 return Status::OK();
585 }
586
GetInOutTypes(const NodeDef & node_def,DataTypeVector * inputs,DataTypeVector * outputs)587 Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
588 DataTypeVector* outputs) {
589 const OpDef* op_def;
590 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));
591 TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, *op_def, inputs, outputs));
592 return Status::OK();
593 }
594
TensorShapeFromString(const string & shape_string,TensorShape * result)595 Status TensorShapeFromString(const string& shape_string, TensorShape* result) {
596 if (shape_string.empty()) {
597 return errors::InvalidArgument("Specificed shape is empty.");
598 }
599 std::vector<int64> dims;
600 if (!str_util::SplitAndParseAsInts(shape_string, ',', &dims)) {
601 return errors::InvalidArgument("Could parse as shape: '", shape_string,
602 "'");
603 }
604 *result = TensorShape(dims);
605 return Status::OK();
606 }
607
CountParameters(const string & name) const608 int TransformFuncContext::CountParameters(const string& name) const {
609 if (params.count(name)) {
610 return params.at(name).size();
611 } else {
612 return 0;
613 }
614 }
615
GetOneStringParameter(const string & name,const string & default_value,string * result) const616 Status TransformFuncContext::GetOneStringParameter(const string& name,
617 const string& default_value,
618 string* result) const {
619 const int params_count = CountParameters(name);
620 if (params_count == 0) {
621 *result = default_value;
622 return Status::OK();
623 } else if (params_count == 1) {
624 *result = params.at(name).at(0);
625 return Status::OK();
626 } else {
627 return errors::InvalidArgument("Expected a single '", name,
628 "' parameter, but found ", params_count,
629 " occurrences");
630 }
631 }
632
GetOneInt32Parameter(const string & name,int32 default_value,int32 * result) const633 Status TransformFuncContext::GetOneInt32Parameter(const string& name,
634 int32 default_value,
635 int32* result) const {
636 const int params_count = CountParameters(name);
637 if (params_count == 0) {
638 *result = default_value;
639 return Status::OK();
640 }
641 string string_value;
642 TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
643 if (!strings::safe_strto32(StringPiece(string_value), result)) {
644 return errors::InvalidArgument("Couldn't interpret the ", name,
645 " argument as a number:", string_value);
646 }
647 return Status::OK();
648 }
649
GetOneInt64Parameter(const string & name,int64 default_value,int64 * result) const650 Status TransformFuncContext::GetOneInt64Parameter(const string& name,
651 int64 default_value,
652 int64* result) const {
653 const int params_count = CountParameters(name);
654 if (params_count == 0) {
655 *result = default_value;
656 return Status::OK();
657 }
658 string string_value;
659 TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
660 if (!strings::safe_strto64(StringPiece(string_value), result)) {
661 return errors::InvalidArgument("Couldn't interpret the ", name,
662 " argument as a number:", string_value);
663 }
664 return Status::OK();
665 }
666
GetOneFloatParameter(const string & name,float default_value,float * result) const667 Status TransformFuncContext::GetOneFloatParameter(const string& name,
668 float default_value,
669 float* result) const {
670 const int params_count = CountParameters(name);
671 if (params_count == 0) {
672 *result = default_value;
673 return Status::OK();
674 }
675 string string_value;
676 TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
677 if (!strings::safe_strtof(string_value.c_str(), result)) {
678 return errors::InvalidArgument(
679 "Couldn't interpret the ", name,
680 " argument as a float number:", string_value);
681 }
682 return Status::OK();
683 }
684
GetOneBoolParameter(const string & name,bool default_value,bool * result) const685 Status TransformFuncContext::GetOneBoolParameter(const string& name,
686 bool default_value,
687 bool* result) const {
688 const int params_count = CountParameters(name);
689 if (params_count == 0) {
690 *result = default_value;
691 return Status::OK();
692 }
693 string string_value;
694 TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
695 if (string_value == "true" || string_value == "1") {
696 *result = true;
697 } else if (string_value == "false" || string_value == "0") {
698 *result = false;
699 } else {
700 return errors::InvalidArgument("Couldn't interpret the ", name,
701 " argument as a boolean:", string_value,
702 " (expected true, false, 0 or 1)");
703 }
704 return Status::OK();
705 }
706
707 } // namespace graph_transforms
708 } // namespace tensorflow
709