1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/tools/graph_transforms/transform_utils.h"
17
18 #include "tensorflow/core/framework/node_def_util.h"
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/core/lib/hash/hash.h"
21 #include "tensorflow/core/lib/strings/numbers.h"
22 #include "tensorflow/core/lib/strings/str_util.h"
23
24 namespace tensorflow {
25 namespace graph_transforms {
26
27 namespace {
IsMerge(const NodeDef & node_def)28 inline bool IsMerge(const NodeDef& node_def) {
29 return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
30 node_def.op() == "_XlaMerge";
31 }
32
RecordMatchedNodes(const NodeMatch & match,std::set<string> * matched_nodes)33 void RecordMatchedNodes(const NodeMatch& match,
34 std::set<string>* matched_nodes) {
35 matched_nodes->insert(match.node.name());
36 for (const NodeMatch& input_match : match.inputs) {
37 RecordMatchedNodes(input_match, matched_nodes);
38 }
39 }
40
Hash64String(const string & input)41 inline uint64 Hash64String(const string& input) {
42 return Hash64(input.data(), input.size());
43 }
44 } // namespace
45
MatchedNodesAsArray(const NodeMatch & match,std::vector<NodeDef> * result)46 void MatchedNodesAsArray(const NodeMatch& match, std::vector<NodeDef>* result) {
47 std::set<string> found_nodes;
48 std::vector<NodeMatch> current_matches = {match};
49 while (!current_matches.empty()) {
50 std::vector<NodeMatch> next_matches;
51 for (const NodeMatch& current_match : current_matches) {
52 if (found_nodes.count(current_match.node.name())) {
53 continue;
54 }
55 found_nodes.insert(current_match.node.name());
56 result->push_back(current_match.node);
57 for (const NodeMatch& input_match : current_match.inputs) {
58 next_matches.push_back(input_match);
59 }
60 }
61 current_matches = next_matches;
62 }
63 }
64
MapNamesToNodes(const GraphDef & graph_def,std::map<string,const NodeDef * > * result)65 void MapNamesToNodes(const GraphDef& graph_def,
66 std::map<string, const NodeDef*>* result) {
67 for (const NodeDef& node : graph_def.node()) {
68 (*result)[node.name()] = &node;
69 }
70 }
71
MapNodesToOutputs(const GraphDef & graph_def,std::map<string,std::vector<const NodeDef * >> * result)72 void MapNodesToOutputs(const GraphDef& graph_def,
73 std::map<string, std::vector<const NodeDef*>>* result) {
74 std::map<string, const NodeDef*> node_map;
75 MapNamesToNodes(graph_def, &node_map);
76 for (const NodeDef& node : graph_def.node()) {
77 for (const string& input : node.input()) {
78 string input_node_name = NodeNameFromInput(input);
79 (*result)[input_node_name].push_back(&node);
80 }
81 }
82 }
83
NodeNamePartsFromInput(const string & input_name,string * prefix,string * node_name,string * suffix)84 void NodeNamePartsFromInput(const string& input_name, string* prefix,
85 string* node_name, string* suffix) {
86 std::vector<string> input_parts = str_util::Split(input_name, ':');
87 if (input_parts.size() < 2) {
88 *suffix = "";
89 } else {
90 *suffix = ":" + input_parts[1];
91 }
92 StringPiece node_name_piece(input_parts[0]);
93 if (absl::ConsumePrefix(&node_name_piece, "^")) {
94 *prefix = "^";
95 } else {
96 *prefix = "";
97 }
98 *node_name = string(node_name_piece);
99 }
100
// Returns just the node-name component of `input_name`, as computed by
// NodeNamePartsFromInput (the prefix and suffix components are discarded).
string NodeNameFromInput(const string& input_name) {
  string discarded_prefix;
  string discarded_suffix;
  string bare_name;
  NodeNamePartsFromInput(input_name, &discarded_prefix, &bare_name,
                         &discarded_suffix);
  return bare_name;
}
108
// Returns `input_name` rebuilt from its prefix, node name and suffix, with
// ":0" substituted when the suffix is empty. This makes "foo" and "foo:0"
// compare equal.
string CanonicalInputName(const string& input_name) {
  string prefix;
  string node_name;
  string suffix;
  NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix);
  const string port = suffix.empty() ? string(":0") : suffix;
  return prefix + node_name + port;
}
119
HashNodeDef(const NodeDef & node)120 uint64 HashNodeDef(const NodeDef& node) {
121 uint64 hash = Hash64String(node.op());
122 hash = Hash64Combine(hash, Hash64String(node.name()));
123 for (const string& input : node.input()) {
124 hash = Hash64Combine(hash, Hash64String(CanonicalInputName(input)));
125 }
126 hash = Hash64Combine(hash, Hash64String(node.device()));
127 std::vector<string> attr_names;
128 attr_names.reserve(node.attr().size());
129 for (const auto& attr : node.attr()) {
130 attr_names.push_back(attr.first);
131 }
132 std::sort(attr_names.begin(), attr_names.end());
133 string attr_serialized;
134 for (const string& attr_name : attr_names) {
135 auto attr = node.attr().at(attr_name);
136 attr.SerializeToString(&attr_serialized);
137 hash = Hash64Combine(hash, Hash64String(attr_serialized));
138 }
139 return hash;
140 }
141
AddNodeInput(const string & input_name,NodeDef * node)142 void AddNodeInput(const string& input_name, NodeDef* node) {
143 *(node->mutable_input()->Add()) = input_name;
144 }
145
CopyNodeAttr(const NodeDef & source,const string & source_key,const string & dest_key,NodeDef * dest)146 void CopyNodeAttr(const NodeDef& source, const string& source_key,
147 const string& dest_key, NodeDef* dest) {
148 CHECK_NE(0, source.attr().count(source_key))
149 << "No key '" << source_key << "' found in " << source.DebugString();
150 (*(dest->mutable_attr()))[dest_key] = source.attr().at(source_key);
151 }
152
GetNodeTensorAttr(const NodeDef & node,const string & key)153 Tensor GetNodeTensorAttr(const NodeDef& node, const string& key) {
154 TensorProto tensor_proto = node.attr().at(key).tensor();
155 Tensor tensor;
156 CHECK(tensor.FromProto(tensor_proto));
157 return tensor;
158 }
159
FilterGraphDef(const GraphDef & input_graph_def,std::function<bool (const NodeDef &)> selector,GraphDef * output_graph_def)160 void FilterGraphDef(const GraphDef& input_graph_def,
161 std::function<bool(const NodeDef&)> selector,
162 GraphDef* output_graph_def) {
163 output_graph_def->mutable_node()->Clear();
164 for (const NodeDef& node : input_graph_def.node()) {
165 if (selector(node)) {
166 *output_graph_def->mutable_node()->Add() = node;
167 }
168 }
169 }
170
RemoveAttributes(const GraphDef & input_graph_def,const std::vector<string> & attributes,GraphDef * output_graph_def)171 void RemoveAttributes(const GraphDef& input_graph_def,
172 const std::vector<string>& attributes,
173 GraphDef* output_graph_def) {
174 output_graph_def->mutable_node()->Clear();
175 for (const NodeDef& node : input_graph_def.node()) {
176 NodeDef* new_node = output_graph_def->mutable_node()->Add();
177 *new_node = node;
178 for (const string& attribute : attributes) {
179 new_node->mutable_attr()->erase(attribute);
180 }
181 }
182 }
183
SortByExecutionOrder(const GraphDef & input_graph_def,GraphDef * output_graph_def)184 Status SortByExecutionOrder(const GraphDef& input_graph_def,
185 GraphDef* output_graph_def) {
186 const int num_nodes = input_graph_def.node_size();
187 std::vector<int> ready;
188 std::vector<int> pending_count;
189 pending_count.reserve(num_nodes);
190 std::vector<gtl::InlinedVector<int, 4>> outputs(num_nodes);
191
192 std::map<string, int> name_index;
193 for (int i = 0; i < input_graph_def.node_size(); ++i) {
194 const NodeDef& node(input_graph_def.node(i));
195 name_index[node.name()] = i;
196 }
197
198 // Parse the inputs for each node.
199 for (int n = 0; n < num_nodes; ++n) {
200 const NodeDef& node_def(input_graph_def.node(n));
201 if (IsMerge(node_def)) {
202 // for merge only wait for one non-control input.
203 int32 num_control_edges = 0;
204 for (int i = 0; i < node_def.input_size(); ++i) {
205 if (absl::StartsWith(node_def.input(i), "^")) {
206 num_control_edges++;
207 }
208 }
209 pending_count.push_back(num_control_edges + 1);
210 } else {
211 pending_count.push_back(node_def.input_size());
212 }
213 if (node_def.input_size() == 0) {
214 ready.push_back(n);
215 continue;
216 }
217 for (int i = 0; i < node_def.input_size(); ++i) {
218 const string& input_name = node_def.input(i);
219 const string& input_node_name = NodeNameFromInput(input_name);
220 if (!name_index.count(input_node_name)) {
221 return errors::InvalidArgument("Node '", node_def.name(),
222 "': Unknown input node '",
223 node_def.input(i), "'");
224 }
225 outputs[name_index[input_node_name]].push_back(n);
226 }
227 }
228
229 int processed = 0;
230 output_graph_def->Clear();
231 // Process the NodeDefs in topological order.
232 // Code above sets this up by filling in ready_ with nodes that have no
233 // inputs, pending_counts_ with the number of inputs for each node and
234 // outputs_ with the outputs of each node.
235 while (!ready.empty()) {
236 int o = ready.back();
237 ready.pop_back();
238 ++processed;
239 const NodeDef& node_def(input_graph_def.node(o));
240 *output_graph_def->mutable_node()->Add() = node_def;
241
242 // Update pending_count for outputs.
243 for (size_t i = 0; i < outputs[o].size(); ++i) {
244 const int output = outputs[o][i];
245 pending_count[output]--;
246 if (pending_count[output] == 0) {
247 ready.push_back(output);
248 }
249 }
250 }
251
252 if (processed < num_nodes) {
253 LOG(WARNING) << "IN " << __func__ << (num_nodes - processed)
254 << " NODES IN A CYCLE";
255 for (int64 i = 0; i < num_nodes; i++) {
256 if (pending_count[i] != 0) {
257 LOG(WARNING) << "PENDING: " << SummarizeNodeDef(input_graph_def.node(i))
258 << "WITH PENDING COUNT = " << pending_count[i];
259 }
260 }
261 return errors::InvalidArgument(num_nodes - processed, " nodes in a cycle");
262 }
263 return Status::OK();
264 }
265
DebugString() const266 string OpTypePattern::DebugString() const {
267 string result = "{" + op + ", {";
268 for (const OpTypePattern& input : inputs) {
269 result += input.DebugString() + ",";
270 }
271 result += "}}";
272 return result;
273 }
274
DebugString() const275 string NodeMatch::DebugString() const {
276 string result = "{";
277 result += node.DebugString();
278 result += ", {";
279 for (const NodeMatch& input : inputs) {
280 result += input.DebugString() + ",";
281 }
282 result += "}}";
283 return result;
284 }
285
GraphMatcher(const GraphDef & graph_def)286 GraphMatcher::GraphMatcher(const GraphDef& graph_def) {
287 SortByExecutionOrder(graph_def, &graph_def_).IgnoreError();
288 MapNamesToNodes(graph_def_, &node_map_);
289 }
290
GetOpTypeMatches(const OpTypePattern & pattern,std::vector<NodeMatch> * matches)291 Status GraphMatcher::GetOpTypeMatches(const OpTypePattern& pattern,
292 std::vector<NodeMatch>* matches) {
293 std::set<string> matched_nodes;
294 for (const NodeDef& node : graph_def_.node()) {
295 // Skip any nodes that are already part of a match.
296 if (matched_nodes.count(node.name())) {
297 continue;
298 }
299 NodeMatch match;
300 if (DoesOpTypeMatch(node, pattern, matched_nodes, &match)) {
301 RecordMatchedNodes(match, &matched_nodes);
302 matches->push_back(match);
303 }
304 }
305 return Status::OK();
306 }
307
DoesOpTypeMatch(const NodeDef & node,const OpTypePattern & pattern,const std::set<string> & previously_matched_nodes,NodeMatch * match)308 bool GraphMatcher::DoesOpTypeMatch(
309 const NodeDef& node, const OpTypePattern& pattern,
310 const std::set<string>& previously_matched_nodes, NodeMatch* match) {
311 VLOG(1) << "Looking at node " << node.DebugString();
312 VLOG(1) << "pattern=" << pattern.DebugString();
313 VLOG(1) << "match=" << match->DebugString();
314 if (previously_matched_nodes.count(node.name())) {
315 VLOG(1) << "node " << node.name() << " has been previously matched";
316 return false;
317 }
318 bool pattern_matched = false;
319 if (pattern.op == "*") {
320 pattern_matched = true;
321 } else {
322 std::vector<string> pattern_ops = str_util::Split(pattern.op, '|');
323 for (const string& pattern_op : pattern_ops) {
324 if (node.op() == pattern_op) {
325 pattern_matched = true;
326 }
327 }
328 }
329 if (!pattern_matched) {
330 VLOG(1) << "node.op() != pattern.op()";
331 return false;
332 }
333 match->node = node;
334 // Ignore any control inputs for pattern-matching purposes
335 std::vector<string> non_control_inputs;
336 for (const string& input : node.input()) {
337 if (!input.empty() && (input[0] != '^')) {
338 non_control_inputs.push_back(input);
339 }
340 }
341 if (pattern.inputs.empty()) {
342 // If there are no inputs, assume that's the end of the pattern.
343 return true;
344 }
345 if (non_control_inputs.size() != pattern.inputs.size()) {
346 VLOG(1) << "non_control_inputs.size() != pattern.inputs.size()";
347 return false;
348 }
349 for (int i = 0; i < pattern.inputs.size(); ++i) {
350 const string& input_node_name = NodeNameFromInput(non_control_inputs[i]);
351 const NodeDef& input_node = *(node_map_[input_node_name]);
352 const OpTypePattern& input_pattern = pattern.inputs[i];
353 match->inputs.push_back(NodeMatch());
354 NodeMatch* input_match = &(match->inputs.back());
355 if (!DoesOpTypeMatch(input_node, input_pattern, previously_matched_nodes,
356 input_match)) {
357 return false;
358 }
359 }
360 return true;
361 }
362
ReplaceMatchingOpTypes(const GraphDef & input_graph_def,const OpTypePattern & pattern,const std::function<Status (const NodeMatch &,const std::set<string> &,const std::set<string> &,std::vector<NodeDef> *)> & node_generator,const ReplaceMatchingOpTypesOptions & options,GraphDef * output_graph_def)363 Status ReplaceMatchingOpTypes(
364 const GraphDef& input_graph_def, const OpTypePattern& pattern,
365 const std::function<Status(const NodeMatch&, const std::set<string>&,
366 const std::set<string>&, std::vector<NodeDef>*)>&
367 node_generator,
368 const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def) {
369 // Start off by retrieving all the matching subgraphs.
370 GraphMatcher matcher(input_graph_def);
371 std::vector<NodeMatch> matches;
372 TF_RETURN_IF_ERROR(matcher.GetOpTypeMatches(pattern, &matches));
373
374 // Do some housekeeping so we can easily look up the resulting matches given
375 // a node name.
376 std::set<string> matched_nodes;
377 std::map<string, const NodeMatch*> matches_by_head_name;
378 for (const NodeMatch& match : matches) {
379 matches_by_head_name[match.node.name()] = &match;
380 RecordMatchedNodes(match, &matched_nodes);
381 }
382 std::map<string, std::vector<const NodeDef*>> outputs_map;
383 MapNodesToOutputs(input_graph_def, &outputs_map);
384
385 // Go through all the nodes in the input graph, see if they are part of a
386 // match or if they can be left untouched.
387 output_graph_def->Clear();
388 for (const NodeDef& input_node : input_graph_def.node()) {
389 if (matches_by_head_name.count(input_node.name())) {
390 // This node is the beginning of a match, so call the replacement function
391 // after setting up some information it will need.
392 const NodeMatch* match = matches_by_head_name[input_node.name()];
393 std::vector<NodeDef> matched_nodes_array;
394 MatchedNodesAsArray(*match, &matched_nodes_array);
395 // This tells us whether a node is part of the current match.
396 std::set<string> matched_nodes_lookup;
397 for (const NodeDef& matched_node : matched_nodes_array) {
398 matched_nodes_lookup.insert(matched_node.name());
399 }
400 // These are helper arrays that the replacement function can use to tell
401 // whether it can safely remove an internal node (because nothing outside
402 // of the match uses it) or whether external nodes depend on it.
403 std::set<string> input_nodes;
404 std::set<string> output_nodes;
405 for (const NodeDef& matched_node : matched_nodes_array) {
406 // Look through all of this node's inputs, and if any of them come from
407 // outside the match, then this should be noted as one of the external
408 // inputs of the subgraph.
409 for (const string& input_name : matched_node.input()) {
410 string input_node_name = NodeNameFromInput(input_name);
411 if (!matched_nodes_lookup.count(input_node_name)) {
412 input_nodes.insert(matched_node.name());
413 }
414 }
415 // Do a reverse input lookup, to see which other nodes use the current
416 // one as an input. If any of those nodes are outside the match
417 // subgraph, then the current node is marked as an output node that
418 // shouldn't be removed.
419 if (outputs_map.count(matched_node.name())) {
420 for (const NodeDef* dependent_node :
421 outputs_map[matched_node.name()]) {
422 if (!matched_nodes_lookup.count(dependent_node->name())) {
423 output_nodes.insert(matched_node.name());
424 }
425 }
426 }
427 }
428 // Call the generator function and add all the returned nodes to the
429 // graph.
430 std::vector<NodeDef> new_nodes;
431 TF_RETURN_IF_ERROR(
432 node_generator(*match, input_nodes, output_nodes, &new_nodes));
433 std::set<string> new_node_names;
434 for (const NodeDef& new_node : new_nodes) {
435 new_node_names.insert(new_node.name());
436 }
437 // Check to make sure the generator function preserved all of the nodes
438 // that are used elsewhere in the graph, and add them back in if not.
439 bool abort_replacement = false;
440 if (!options.allow_inconsistencies) {
441 for (const string& expected_output : output_nodes) {
442 if (!new_node_names.count(expected_output)) {
443 LOG(WARNING) << "Expected " << expected_output
444 << " to be preserved.";
445 abort_replacement = true;
446 }
447 }
448 }
449 if (abort_replacement) {
450 LOG(WARNING) << "Generator function didn't preserve needed nodes, "
451 << "copying old replacements back in instead.";
452 std::vector<NodeDef> old_nodes;
453 MatchedNodesAsArray(*match, &old_nodes);
454 for (const NodeDef& old_node : old_nodes) {
455 NodeDef* added_node = output_graph_def->mutable_node()->Add();
456 *added_node = old_node;
457 }
458 } else {
459 for (const NodeDef& new_node : new_nodes) {
460 NodeDef* added_node = output_graph_def->mutable_node()->Add();
461 *added_node = new_node;
462 }
463 }
464 } else if (!matched_nodes.count(input_node.name())) {
465 // This node isn't part of any match, so just copy it over.
466 NodeDef* added_node = output_graph_def->mutable_node()->Add();
467 *added_node = input_node;
468 } else {
469 // Do nothing, because this is an internal part of a matching subgraph,
470 // and so will have been replaced by a new replacement subgraph.
471 }
472 }
473
474 return Status::OK();
475 }
476
RenameNodeInputs(const GraphDef & input_graph_def,const std::map<string,string> & inputs_to_rename,const std::unordered_set<string> & nodes_to_ignore,GraphDef * output_graph_def)477 Status RenameNodeInputs(const GraphDef& input_graph_def,
478 const std::map<string, string>& inputs_to_rename,
479 const std::unordered_set<string>& nodes_to_ignore,
480 GraphDef* output_graph_def) {
481 std::map<string, std::vector<std::pair<string, string>>>
482 canonical_inputs_to_rename;
483 for (const auto& input_to_rename : inputs_to_rename) {
484 canonical_inputs_to_rename[NodeNameFromInput(input_to_rename.first)]
485 .push_back({input_to_rename.first, input_to_rename.second});
486 }
487
488 output_graph_def->Clear();
489 for (const NodeDef& node : input_graph_def.node()) {
490 NodeDef* new_node = output_graph_def->mutable_node()->Add();
491 *new_node = node;
492 new_node->mutable_input()->Clear();
493 for (const string& input_name : node.input()) {
494 std::set<string> already_visited;
495 string new_input_name = input_name;
496 while (
497 canonical_inputs_to_rename.count(NodeNameFromInput(new_input_name))) {
498 string input_node_name = NodeNameFromInput(new_input_name);
499 if (already_visited.count(input_node_name)) {
500 return errors::InvalidArgument(
501 "RenameNodeInputs argument contains a cycle for ",
502 input_node_name);
503 }
504 already_visited.insert(input_node_name);
505 if (nodes_to_ignore.count(node.name())) {
506 break;
507 }
508 bool any_match_found = false;
509 for (const std::pair<string, string>& input_to_rename :
510 canonical_inputs_to_rename.at(input_node_name)) {
511 const string& source_name = input_to_rename.first;
512 const string& dest_name = input_to_rename.second;
513 bool is_match;
514 string match_name;
515 if (str_util::EndsWith(source_name, ":*")) {
516 is_match = true;
517 string prefix;
518 string unused_node_name;
519 string suffix;
520 NodeNamePartsFromInput(new_input_name, &prefix, &unused_node_name,
521 &suffix);
522 match_name = prefix + dest_name + suffix;
523 } else {
524 is_match = (CanonicalInputName(source_name) ==
525 CanonicalInputName(new_input_name));
526 match_name = dest_name;
527 }
528 if (is_match) {
529 new_input_name = match_name;
530 any_match_found = true;
531 }
532 }
533 if (!any_match_found) {
534 break;
535 }
536 }
537 *(new_node->mutable_input()->Add()) = new_input_name;
538 }
539 }
540 return Status::OK();
541 }
542
CopyOriginalMatch(const NodeMatch & match,std::vector<NodeDef> * new_nodes)543 void CopyOriginalMatch(const NodeMatch& match,
544 std::vector<NodeDef>* new_nodes) {
545 std::vector<NodeDef> old_nodes;
546 MatchedNodesAsArray(match, &old_nodes);
547 for (const NodeDef& old_node : old_nodes) {
548 new_nodes->push_back(old_node);
549 }
550 }
551
GetTransformRegistry()552 TransformRegistry* GetTransformRegistry() {
553 static TransformRegistry transform_registry;
554 return &transform_registry;
555 }
556
FindInvalidInputs(const GraphDef & graph_def,std::vector<std::pair<string,string>> * invalid_inputs)557 void FindInvalidInputs(const GraphDef& graph_def,
558 std::vector<std::pair<string, string>>* invalid_inputs) {
559 std::map<string, const NodeDef*> node_map;
560 MapNamesToNodes(graph_def, &node_map);
561
562 for (const NodeDef& node : graph_def.node()) {
563 for (const string& input : node.input()) {
564 string input_node = NodeNameFromInput(input);
565 if (!node_map.count(input_node)) {
566 invalid_inputs->push_back({node.name(), input_node});
567 }
568 }
569 }
570 }
571
IsGraphValid(const GraphDef & graph_def)572 Status IsGraphValid(const GraphDef& graph_def) {
573 std::vector<std::pair<string, string>> invalid_inputs;
574 FindInvalidInputs(graph_def, &invalid_inputs);
575 if (!invalid_inputs.empty()) {
576 std::map<string, const NodeDef*> node_map;
577 MapNamesToNodes(graph_def, &node_map);
578 for (const std::pair<string, string>& invalid_input : invalid_inputs) {
579 LOG(ERROR) << "Invalid input " << invalid_input.second << " for node "
580 << invalid_input.first << " - "
581 << node_map[invalid_input.first]->DebugString();
582 }
583 return errors::Internal(
584 "Invalid graph with inputs referring to nonexistent nodes");
585 }
586 return Status::OK();
587 }
588
GetInOutTypes(const NodeDef & node_def,DataTypeVector * inputs,DataTypeVector * outputs)589 Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
590 DataTypeVector* outputs) {
591 const OpDef* op_def;
592 TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));
593 TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, *op_def, inputs, outputs));
594 return Status::OK();
595 }
596
TensorShapeFromString(const string & shape_string,TensorShape * result)597 Status TensorShapeFromString(const string& shape_string, TensorShape* result) {
598 if (shape_string.empty()) {
599 return errors::InvalidArgument("Specified shape is empty.");
600 }
601 std::vector<string> dims_as_str = str_util::Split(shape_string, ",");
602 std::vector<int64> dims;
603 for (const string& dim : dims_as_str) {
604 int64 tmp;
605 if (strings::safe_strto64(dim, &tmp)) {
606 dims.push_back(tmp);
607 } else {
608 return errors::InvalidArgument("Could parse as shape: '", shape_string,
609 "'");
610 }
611 }
612 *result = TensorShape(dims);
613 return Status::OK();
614 }
615
CountParameters(const string & name) const616 int TransformFuncContext::CountParameters(const string& name) const {
617 if (params.count(name)) {
618 return params.at(name).size();
619 } else {
620 return 0;
621 }
622 }
623
GetOneStringParameter(const string & name,const string & default_value,string * result) const624 Status TransformFuncContext::GetOneStringParameter(const string& name,
625 const string& default_value,
626 string* result) const {
627 const int params_count = CountParameters(name);
628 if (params_count == 0) {
629 *result = default_value;
630 return Status::OK();
631 } else if (params_count == 1) {
632 *result = params.at(name).at(0);
633 return Status::OK();
634 } else {
635 return errors::InvalidArgument("Expected a single '", name,
636 "' parameter, but found ", params_count,
637 " occurrences");
638 }
639 }
640
GetOneInt32Parameter(const string & name,int32 default_value,int32 * result) const641 Status TransformFuncContext::GetOneInt32Parameter(const string& name,
642 int32 default_value,
643 int32* result) const {
644 const int params_count = CountParameters(name);
645 if (params_count == 0) {
646 *result = default_value;
647 return Status::OK();
648 }
649 string string_value;
650 TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
651 if (!strings::safe_strto32(StringPiece(string_value), result)) {
652 return errors::InvalidArgument("Couldn't interpret the ", name,
653 " argument as a number:", string_value);
654 }
655 return Status::OK();
656 }
657
GetOneInt64Parameter(const string & name,int64 default_value,int64 * result) const658 Status TransformFuncContext::GetOneInt64Parameter(const string& name,
659 int64 default_value,
660 int64* result) const {
661 const int params_count = CountParameters(name);
662 if (params_count == 0) {
663 *result = default_value;
664 return Status::OK();
665 }
666 string string_value;
667 TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
668 if (!strings::safe_strto64(StringPiece(string_value), result)) {
669 return errors::InvalidArgument("Couldn't interpret the ", name,
670 " argument as a number:", string_value);
671 }
672 return Status::OK();
673 }
674
GetOneFloatParameter(const string & name,float default_value,float * result) const675 Status TransformFuncContext::GetOneFloatParameter(const string& name,
676 float default_value,
677 float* result) const {
678 const int params_count = CountParameters(name);
679 if (params_count == 0) {
680 *result = default_value;
681 return Status::OK();
682 }
683 string string_value;
684 TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
685 if (!strings::safe_strtof(string_value.c_str(), result)) {
686 return errors::InvalidArgument(
687 "Couldn't interpret the ", name,
688 " argument as a float number:", string_value);
689 }
690 return Status::OK();
691 }
692
GetOneBoolParameter(const string & name,bool default_value,bool * result) const693 Status TransformFuncContext::GetOneBoolParameter(const string& name,
694 bool default_value,
695 bool* result) const {
696 const int params_count = CountParameters(name);
697 if (params_count == 0) {
698 *result = default_value;
699 return Status::OK();
700 }
701 string string_value;
702 TF_RETURN_IF_ERROR(GetOneStringParameter(name, "", &string_value));
703 if (string_value == "true" || string_value == "1") {
704 *result = true;
705 } else if (string_value == "false" || string_value == "0") {
706 *result = false;
707 } else {
708 return errors::InvalidArgument("Couldn't interpret the ", name,
709 " argument as a boolean:", string_value,
710 " (expected true, false, 0 or 1)");
711 }
712 return Status::OK();
713 }
714
715 } // namespace graph_transforms
716 } // namespace tensorflow
717