• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/constant_folding.h"
17 #include "tensorflow/core/common_runtime/graph_constructor.h"
18 #include "tensorflow/core/graph/node_builder.h"
19 #include "tensorflow/core/graph/subgraph.h"
20 #include "tensorflow/core/lib/strings/str_util.h"
21 #include "tensorflow/core/platform/init_main.h"
22 #include "tensorflow/core/public/session.h"
23 #include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
24 #include "tensorflow/tools/graph_transforms/transform_utils.h"
25 
26 namespace tensorflow {
27 namespace graph_transforms {
28 
29 // Clears the device field of all ops in the graph.
InsertLogging(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)30 Status InsertLogging(const GraphDef& input_graph_def,
31                      const TransformFuncContext& context,
32                      GraphDef* output_graph_def) {
33   std::unordered_set<string> ops;
34   bool has_ops;
35   if (context.params.count("op")) {
36     has_ops = true;
37     for (const string& op : context.params.at("op")) {
38       ops.insert(op);
39     }
40   } else {
41     has_ops = false;
42   }
43 
44   std::unordered_set<string> prefixes;
45   bool has_prefixes;
46   if (context.params.count("prefix")) {
47     has_prefixes = true;
48     for (const string& prefix : context.params.at("prefix")) {
49       prefixes.insert(prefix);
50     }
51   } else {
52     has_prefixes = false;
53   }
54 
55   string message;
56   TF_RETURN_IF_ERROR(context.GetOneStringParameter("message", "", &message));
57 
58   bool show_name;
59   TF_RETURN_IF_ERROR(
60       context.GetOneBoolParameter("show_name", false, &show_name));
61 
62   bool show_op;
63   TF_RETURN_IF_ERROR(context.GetOneBoolParameter("show_op", false, &show_op));
64 
65   int32_t first_n;
66   TF_RETURN_IF_ERROR(context.GetOneInt32Parameter("first_n", -1, &first_n));
67 
68   int32_t summarize;
69   TF_RETURN_IF_ERROR(
70       context.GetOneInt32Parameter("summarize", 1024, &summarize));
71 
72   std::unordered_map<string, std::set<int>> node_outputs;
73   for (const NodeDef& node : input_graph_def.node()) {
74     for (const string& input : node.input()) {
75       const string canonical_input = CanonicalInputName(input);
76       string prefix;
77       string name;
78       string suffix;
79       NodeNamePartsFromInput(canonical_input, &prefix, &name, &suffix);
80       const string output_index_string = suffix.substr(1, suffix.size() - 1);
81       int32_t output_index;
82       if (!strings::safe_strto32(output_index_string, &output_index)) {
83         return errors::InvalidArgument("Couldn't understand output number in ",
84                                        input);
85       }
86       node_outputs[name].insert(output_index);
87     }
88   }
89 
90   std::map<string, string> inputs_to_rename;
91   std::unordered_set<string> ignore_when_renaming;
92   GraphDef logged_graph_def;
93   for (const NodeDef& node : input_graph_def.node()) {
94     NodeDef* new_node = logged_graph_def.mutable_node()->Add();
95     *new_node = node;
96     if (node_outputs[node.name()].empty()) {
97       // There were no outputs found to this node, so skip it.
98       continue;
99     }
100     const bool op_matches = (ops.count(node.op()) > 0);
101     bool prefix_matches = false;
102     for (const string& prefix : prefixes) {
103       if (absl::StartsWith(node.name(), prefix)) {
104         prefix_matches = true;
105       }
106     }
107     // If we're not looking for ops, or we found the right op, and if we're not
108     // looking for prefixes or we found the right prefix, then add logging here.
109     if ((!has_ops || op_matches) && (!has_prefixes || prefix_matches)) {
110       const string name_suffix = "__print__";
111       DataTypeVector input_types;
112       DataTypeVector output_types;
113       TF_RETURN_IF_ERROR(GetInOutTypes(node, &input_types, &output_types));
114       NodeDef* print_node = logged_graph_def.mutable_node()->Add();
115       print_node->set_op("Print");
116       print_node->set_name(strings::StrCat(node.name(), name_suffix));
117       string node_message;
118       if (show_op) {
119         node_message += ";" + node.op() + ";";
120       }
121       if (show_name) {
122         node_message += ";" + print_node->name() + ";";
123       }
124       node_message += message;
125       SetNodeAttr("message", node_message, print_node);
126       SetNodeAttr("first_n", first_n, print_node);
127       SetNodeAttr("summarize", summarize, print_node);
128       print_node->add_input(node.name() + ":0");
129       SetNodeAttr("T", output_types[0], print_node);
130       for (int output_index : node_outputs[node.name()]) {
131         print_node->add_input(strings::StrCat(node.name(), ":", output_index));
132       }
133       SetNodeAttr("U", output_types, print_node);
134       ignore_when_renaming.insert(print_node->name());
135       // Rewrite the graph so all references to the first input of the original
136       // op now pull from the print op instead, so it's executed.
137       inputs_to_rename[node.name() + ":0"] =
138           strings::StrCat(node.name(), name_suffix, ":0");
139     }
140   }
141 
142   output_graph_def->Clear();
143   return RenameNodeInputs(logged_graph_def, inputs_to_rename,
144                           ignore_when_renaming, output_graph_def);
145 }
146 
// Registers InsertLogging as the graph transform named "insert_logging".
147 REGISTER_GRAPH_TRANSFORM("insert_logging", InsertLogging);
148 
149 }  // namespace graph_transforms
150 }  // namespace tensorflow
151