1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
16
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
23
24 #include "tensorflow/lite/toco/toco_port.h"
25 #include "tensorflow/lite/toco/tooling_util.h"
26 #include "tensorflow/core/platform/logging.h"
27
28 namespace toco {
29
30 namespace {
31
PrintModelStats(const string & label,const Model & model)32 void PrintModelStats(const string& label, const Model& model) {
33 int quantized_arrays = 0;
34 for (const auto& array : model.GetArrayMap()) {
35 if (array.second->quantization_params) {
36 quantized_arrays++;
37 }
38 }
39 LOG(INFO) << label << ": " << model.operators.size() << " operators, "
40 << model.GetArrayMap().size() << " arrays (" << quantized_arrays
41 << " quantized)";
42 }
43
44 // Some graphs have RNN back-edges that are discardable, having been
45 // created typically by TensorFlow import rather than specified by the user.
46 // Such graphs might have cycles (closed by RNN back-edges) that may be pruned.
47 // Local graph transformations can't identify such global features,
48 // so this function performs this global transformation.
49 //
50 // The other (and related) thing that is peculiar about RNN back-edges
51 // is that they do not prevent the arrays that they touch, from being
52 // pruned. Thus, they may refer to array names which no longer exist.
53 // The intent is for that to result in the eventual pruning of such
54 // 'dangling' RNN back-edges. We perform this pruning at the end of this
55 // function, as the pruning of connected components done here may leave
56 // more RNN back-edges dangling.
DiscardUselessConnectedComponentsAndRNNBackEdges(Model * model)57 void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
58 // Identify the set of arrays that are in 'useful' connected components
59 // of the graph, which means connected to output arrays.
60 std::unordered_set<string> useful_arrays;
61 for (const string& output_array : model->flags.output_arrays()) {
62 useful_arrays.insert(output_array);
63 }
64 bool found_new_useful_arrays;
65 do {
66 found_new_useful_arrays = false;
67 for (const auto& op : model->operators) {
68 bool op_touches_useful_arrays = false;
69 for (const string& output : op->outputs) {
70 op_touches_useful_arrays |= useful_arrays.count(output);
71 }
72 if (op_touches_useful_arrays) {
73 for (const string& input : op->inputs) {
74 found_new_useful_arrays |= !useful_arrays.count(input);
75 useful_arrays.insert(input);
76 }
77 for (const string& output : op->outputs) {
78 found_new_useful_arrays |= !useful_arrays.count(output);
79 useful_arrays.insert(output);
80 }
81 }
82 }
83 for (const auto& rnn_state : model->flags.rnn_states()) {
84 bool rnn_back_edge_touches_useful_arrays =
85 useful_arrays.count(rnn_state.state_array());
86 if (rnn_back_edge_touches_useful_arrays) {
87 found_new_useful_arrays |=
88 !useful_arrays.count(rnn_state.back_edge_source_array());
89 useful_arrays.insert(rnn_state.back_edge_source_array());
90 }
91 }
92 } while (found_new_useful_arrays);
93 // Erase arrays that aren't useful, and that are discardable.
94 model->EraseArrays([&](const string& name) {
95 return (!useful_arrays.count(name) && IsDiscardableArray(*model, name));
96 });
97 // Erase operators that do not produce a useful output array.
98 for (auto it = model->operators.begin(); it != model->operators.end();) {
99 // Only need to test the first output, as we simultaneously added all of
100 // an operator's outputs to the list of output arrays.
101 if (useful_arrays.count((*it)->outputs[0])) {
102 ++it;
103 } else {
104 for (const string& output : (*it)->outputs) {
105 CHECK(!useful_arrays.count(output));
106 }
107 it = model->operators.erase(it);
108 }
109 }
110 // Erase RNN back-edges that are 'dangling' i.e. that touch an array
111 // that no longer exists. This should only happen for discardable RNN
112 // back-edges.
113 std::vector<RnnState> rnn_states_to_keep;
114 for (const auto& rnn_state : model->flags.rnn_states()) {
115 const bool dangling =
116 !model->HasArray(rnn_state.back_edge_source_array()) ||
117 !model->HasArray(rnn_state.state_array());
118 if (dangling) {
119 CHECK(rnn_state.discardable());
120 } else {
121 rnn_states_to_keep.push_back(rnn_state);
122 }
123 }
124 model->flags.clear_rnn_states();
125 for (const auto& rnn_state : rnn_states_to_keep) {
126 *model->flags.add_rnn_states() = rnn_state;
127 }
128 }
129
GraphTransformationsPass(int increment,Model * model,const GraphTransformationsSet & transformations,tensorflow::Status * status)130 bool GraphTransformationsPass(int increment, Model* model,
131 const GraphTransformationsSet& transformations,
132 tensorflow::Status* status) {
133 CHECK(increment == 1 || increment == -1);
134 bool changed = false;
135 if (model->operators.empty()) {
136 LOG(INFO) << "Model is empty!!!";
137 return false;
138 }
139 int op_index = increment == 1 ? 0 : model->operators.size() - 1;
140 while (true) {
141 bool changed_now = false;
142 // Loop over all transformations at the current position in the graph.
143 for (const auto& transformation : transformations) {
144 CHECK(!changed_now);
145 CHECK(transformation->Messages().empty());
146 *status = transformation->Run(model, op_index, &changed_now);
147 if (!status->ok()) {
148 return false;
149 }
150 const char* made_a_change_msg =
151 changed_now ? "made a change" : "did NOT make a change";
152 const int log_level =
153 changed_now ? kLogLevelModelChanged : kLogLevelModelUnchanged;
154 if (transformation->Messages().empty()) {
155 VLOG(log_level) << transformation->Name() << " " << made_a_change_msg
156 << " at op_index=" << op_index << "/"
157 << model->operators.size() - 1;
158 }
159 for (const string& message : transformation->Messages()) {
160 VLOG(log_level) << transformation->Name() << " " << made_a_change_msg
161 << " at op_index=" << op_index << "/"
162 << model->operators.size() - 1 << ": " << message;
163 }
164 transformation->ClearMessages();
165 if (changed_now) {
166 DumpGraphvizVideoFrame(*model);
167 if (model->operators.empty()) return true;
168 op_index = std::min<int>(op_index, model->operators.size() - 1);
169 // Uncomment for debugging
170 // CheckInvariants(*model);
171 }
172 if (changed_now) {
173 break;
174 }
175 }
176 if (changed_now) {
177 changed = true;
178 } else {
179 const int op_index_last =
180 increment == 1 ? model->operators.size() - 1 : 0;
181 if (op_index == op_index_last) {
182 break;
183 }
184 op_index += increment;
185 }
186 }
187 DiscardUselessConnectedComponentsAndRNNBackEdges(model);
188 return changed;
189 }
190
191 } // namespace
192
RunGraphTransformationsWithStatus(Model * model,const string & msg,const GraphTransformationsSet & transformations)193 tensorflow::Status RunGraphTransformationsWithStatus(
194 Model* model, const string& msg,
195 const GraphTransformationsSet& transformations) {
196 PrintModelStats(toco::port::StringF("Before %s", msg), *model);
197 int pass_index = 0;
198 tensorflow::Status status;
199 while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model,
200 transformations, &status)) {
201 pass_index++;
202 const auto& label =
203 toco::port::StringF("After %s pass %d", msg, pass_index);
204 PrintModelStats(label, *model);
205 CheckInvariants(*model);
206 }
207 return status;
208 }
209
210 } // namespace toco
211