• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
16 
17 #include <algorithm>
18 #include <memory>
19 #include <string>
20 #include <unordered_map>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/lite/toco/toco_port.h"
25 #include "tensorflow/lite/toco/tooling_util.h"
26 #include "tensorflow/core/platform/logging.h"
27 
28 namespace toco {
29 
30 namespace {
31 
PrintModelStats(const string & label,const Model & model)32 void PrintModelStats(const string& label, const Model& model) {
33   int quantized_arrays = 0;
34   for (const auto& array : model.GetArrayMap()) {
35     if (array.second->quantization_params) {
36       quantized_arrays++;
37     }
38   }
39   LOG(INFO) << label << ": " << model.operators.size() << " operators, "
40             << model.GetArrayMap().size() << " arrays (" << quantized_arrays
41             << " quantized)";
42 }
43 
44 // Some graphs have RNN back-edges that are discardable, having been
45 // created typically by TensorFlow import rather than specified by the user.
46 // Such graphs might have cycles (closed by RNN back-edges) that may be pruned.
47 // Local graph transformations can't identify such global features,
48 // so this function performs this global transformation.
49 //
50 // The other (and related) thing that is peculiar about RNN back-edges
51 // is that they do not prevent the arrays that they touch, from being
52 // pruned. Thus, they may refer to array names which no longer exist.
53 // The intent is for that to result in the eventual pruning of such
54 // 'dangling' RNN back-edges. We perform this pruning at the end of this
55 // function, as the pruning of connected components done here may leave
56 // more RNN back-edges dangling.
DiscardUselessConnectedComponentsAndRNNBackEdges(Model * model)57 void DiscardUselessConnectedComponentsAndRNNBackEdges(Model* model) {
58   // Identify the set of arrays that are in 'useful' connected components
59   // of the graph, which means connected to output arrays.
60   std::unordered_set<string> useful_arrays;
61   for (const string& output_array : model->flags.output_arrays()) {
62     useful_arrays.insert(output_array);
63   }
64   bool found_new_useful_arrays;
65   do {
66     found_new_useful_arrays = false;
67     for (const auto& op : model->operators) {
68       bool op_touches_useful_arrays = false;
69       for (const string& output : op->outputs) {
70         op_touches_useful_arrays |= useful_arrays.count(output);
71       }
72       if (op_touches_useful_arrays) {
73         for (const string& input : op->inputs) {
74           found_new_useful_arrays |= !useful_arrays.count(input);
75           useful_arrays.insert(input);
76         }
77         for (const string& output : op->outputs) {
78           found_new_useful_arrays |= !useful_arrays.count(output);
79           useful_arrays.insert(output);
80         }
81       }
82     }
83     for (const auto& rnn_state : model->flags.rnn_states()) {
84       bool rnn_back_edge_touches_useful_arrays =
85           useful_arrays.count(rnn_state.state_array());
86       if (rnn_back_edge_touches_useful_arrays) {
87         found_new_useful_arrays |=
88             !useful_arrays.count(rnn_state.back_edge_source_array());
89         useful_arrays.insert(rnn_state.back_edge_source_array());
90       }
91     }
92   } while (found_new_useful_arrays);
93   // Erase arrays that aren't useful, and that are discardable.
94   model->EraseArrays([&](const string& name) {
95     return (!useful_arrays.count(name) && IsDiscardableArray(*model, name));
96   });
97   // Erase operators that do not produce a useful output array.
98   for (auto it = model->operators.begin(); it != model->operators.end();) {
99     // Only need to test the first output, as we simultaneously added all of
100     // an operator's outputs to the list of output arrays.
101     if (useful_arrays.count((*it)->outputs[0])) {
102       ++it;
103     } else {
104       for (const string& output : (*it)->outputs) {
105         CHECK(!useful_arrays.count(output));
106       }
107       it = model->operators.erase(it);
108     }
109   }
110   // Erase RNN back-edges that are 'dangling' i.e. that touch an array
111   // that no longer exists. This should only happen for discardable RNN
112   // back-edges.
113   std::vector<RnnState> rnn_states_to_keep;
114   for (const auto& rnn_state : model->flags.rnn_states()) {
115     const bool dangling =
116         !model->HasArray(rnn_state.back_edge_source_array()) ||
117         !model->HasArray(rnn_state.state_array());
118     if (dangling) {
119       CHECK(rnn_state.discardable());
120     } else {
121       rnn_states_to_keep.push_back(rnn_state);
122     }
123   }
124   model->flags.clear_rnn_states();
125   for (const auto& rnn_state : rnn_states_to_keep) {
126     *model->flags.add_rnn_states() = rnn_state;
127   }
128 }
129 
GraphTransformationsPass(int increment,Model * model,const GraphTransformationsSet & transformations,tensorflow::Status * status)130 bool GraphTransformationsPass(int increment, Model* model,
131                               const GraphTransformationsSet& transformations,
132                               tensorflow::Status* status) {
133   CHECK(increment == 1 || increment == -1);
134   bool changed = false;
135   if (model->operators.empty()) {
136     LOG(INFO) << "Model is empty!!!";
137     return false;
138   }
139   int op_index = increment == 1 ? 0 : model->operators.size() - 1;
140   while (true) {
141     bool changed_now = false;
142     // Loop over all transformations at the current position in the graph.
143     for (const auto& transformation : transformations) {
144       CHECK(!changed_now);
145       CHECK(transformation->Messages().empty());
146       *status = transformation->Run(model, op_index, &changed_now);
147       if (!status->ok()) {
148         return false;
149       }
150       const char* made_a_change_msg =
151           changed_now ? "made a change" : "did NOT make a change";
152       const int log_level =
153           changed_now ? kLogLevelModelChanged : kLogLevelModelUnchanged;
154       if (transformation->Messages().empty()) {
155         VLOG(log_level) << transformation->Name() << " " << made_a_change_msg
156                         << " at op_index=" << op_index << "/"
157                         << model->operators.size() - 1;
158       }
159       for (const string& message : transformation->Messages()) {
160         VLOG(log_level) << transformation->Name() << " " << made_a_change_msg
161                         << " at op_index=" << op_index << "/"
162                         << model->operators.size() - 1 << ": " << message;
163       }
164       transformation->ClearMessages();
165       if (changed_now) {
166         DumpGraphvizVideoFrame(*model);
167         if (model->operators.empty()) return true;
168         op_index = std::min<int>(op_index, model->operators.size() - 1);
169         // Uncomment for debugging
170         // CheckInvariants(*model);
171       }
172       if (changed_now) {
173         break;
174       }
175     }
176     if (changed_now) {
177       changed = true;
178     } else {
179       const int op_index_last =
180           increment == 1 ? model->operators.size() - 1 : 0;
181       if (op_index == op_index_last) {
182         break;
183       }
184       op_index += increment;
185     }
186   }
187   DiscardUselessConnectedComponentsAndRNNBackEdges(model);
188   return changed;
189 }
190 
191 }  // namespace
192 
RunGraphTransformationsWithStatus(Model * model,const string & msg,const GraphTransformationsSet & transformations)193 tensorflow::Status RunGraphTransformationsWithStatus(
194     Model* model, const string& msg,
195     const GraphTransformationsSet& transformations) {
196   PrintModelStats(toco::port::StringF("Before %s", msg), *model);
197   int pass_index = 0;
198   tensorflow::Status status;
199   while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model,
200                                   transformations, &status)) {
201     pass_index++;
202     const auto& label =
203         toco::port::StringF("After %s pass %d", msg, pass_index);
204     PrintModelStats(label, *model);
205     CheckInvariants(*model);
206   }
207   return status;
208 }
209 
210 }  // namespace toco
211