1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
17
18 #include <deque>
19 #include <string>
20 #include <vector>
21
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/lite/delegates/gpu/common/model.h"
25
26 namespace tflite {
27 namespace gpu {
28
Apply(const std::string & name,SequenceTransformation * transformation)29 bool ModelTransformer::Apply(const std::string& name,
30 SequenceTransformation* transformation) {
31 // Seed transformations with starting node. Each node may start a chain of
32 // transformations.
33 for (auto input : graph_->inputs()) {
34 for (auto node : graph_->FindConsumers(input->id)) {
35 AddNodeToProcess(node);
36 }
37 }
38 while (!to_process_.empty()) {
39 auto node = graph_->GetNode(to_process_.front());
40 if (node) {
41 if (!ApplyStartingWithNode(name, transformation, node)) {
42 return false;
43 }
44 }
45 to_process_.pop_front();
46 }
47 processed_.clear();
48 return true;
49 }
50
Apply(const std::string & name,NodeTransformation * transformation)51 bool ModelTransformer::Apply(const std::string& name,
52 NodeTransformation* transformation) {
53 // Apply a transformation only to nodes that are present in the graph before
54 // transformation.
55 std::vector<NodeId> nodes;
56 for (auto node : graph_->nodes()) {
57 nodes.push_back(node->id);
58 }
59 for (auto node_id : nodes) {
60 auto node = graph_->GetNode(node_id);
61 if (!node) {
62 continue;
63 }
64 auto result = transformation->ApplyToNode(node, graph_);
65 if (result.status == TransformStatus::INVALID) {
66 return false;
67 }
68 if (reporter_) {
69 if (result.status == TransformStatus::APPLIED) {
70 reporter_->AppliedTransformation(name, std::to_string(node_id),
71 result.message);
72 }
73 if (result.status == TransformStatus::DECLINED) {
74 reporter_->DeclinedTransformation(name, std::to_string(node_id),
75 result.message);
76 }
77 }
78 }
79 return true;
80 }
81
ApplyStartingWithNode(const std::string & name,SequenceTransformation * transformation,Node * begin)82 bool ModelTransformer::ApplyStartingWithNode(
83 const std::string& name, SequenceTransformation* transformation,
84 Node* begin) {
85 int expected_sequence_length = transformation->ExpectedSequenceLength();
86
87 std::deque<NodeId> sequence;
88 std::vector<Node*> nodes;
89 nodes.reserve(transformation->ExpectedSequenceLength());
90 sequence.push_back(begin->id);
91
92 // Go over nodes with sequence sliding window of size
93 // expected_sequence_length until a node with multiple dependents is found.
94 while (true) {
95 // Apply transformation if possible.
96 if (sequence.size() == expected_sequence_length) {
97 nodes.clear();
98 for (NodeId id : sequence) {
99 // Nodes present in sequence should be present in a graph. If they are
100 // not, then this transformation changes a graph but didn't say it.
101 Node* node = graph_->GetNode(id);
102 if (node == nullptr) {
103 return false;
104 }
105 nodes.push_back(node);
106 }
107
108 NodeId first_in_sequence = sequence.front();
109 auto preceding_node =
110 graph_->FindProducer(graph_->FindInputs(first_in_sequence)[0]->id);
111 auto result = transformation->ApplyToNodesSequence(nodes, graph_);
112 if (result.status == TransformStatus::INVALID) {
113 // graph is broken now.
114 return false;
115 }
116 if (result.status == TransformStatus::DECLINED) {
117 if (reporter_) {
118 reporter_->DeclinedTransformation(name, absl::StrJoin(sequence, "+"),
119 result.message);
120 }
121 } else if (result.status == TransformStatus::APPLIED) {
122 if (reporter_) {
123 reporter_->AppliedTransformation(name, absl::StrJoin(sequence, "+"),
124 result.message);
125 }
126 // Also remove first node of a sequence from a set of processed node.
127 // Out of all nodes in a sequence only first one may have been added
128 // to "processed" set because other nodes do not have more than one
129 // dependent. However, if a sequence is changed, then processing needs
130 // to be restarted again.
131 processed_.erase(first_in_sequence);
132 // Transformation was successful. Restart sequence from the node that
133 // precedes current sequence.
134 if (preceding_node) {
135 processed_.erase(preceding_node->id);
136 AddNodeToProcess(preceding_node);
137 } else {
138 // This is the first node in the graph. Re-seed transformation.
139 for (auto input : graph_->inputs()) {
140 for (auto node : graph_->FindConsumers(input->id)) {
141 AddNodeToProcess(node);
142 }
143 }
144 }
145 return true;
146 }
147 }
148
149 // Try to extend current sequence.
150 Node* next_node_in_sequence = nullptr;
151 bool has_multiple_children = false;
152
153 // Check that all outputs from last node are consumed by a single node.
154 for (auto output_value : graph_->FindOutputs(sequence.back())) {
155 for (auto dependent : graph_->FindConsumers(output_value->id)) {
156 if (has_multiple_children) {
157 AddNodeToProcess(dependent);
158 } else if (next_node_in_sequence == nullptr) {
159 next_node_in_sequence = dependent;
160 } else if (next_node_in_sequence != dependent) {
161 // There are more than two nodes depend on the output from end node,
162 // therefore here a sequence stops and new will start. Push all such
163 // nodes.
164 has_multiple_children = true;
165 AddNodeToProcess(dependent);
166 AddNodeToProcess(next_node_in_sequence);
167 }
168 }
169 }
170
171 // Now check that next node has inputs only produced by the last node.
172 if (!has_multiple_children && next_node_in_sequence) {
173 for (auto input : graph_->FindInputs(next_node_in_sequence->id)) {
174 auto producer = graph_->FindProducer(input->id);
175 if (producer == nullptr || producer->id != sequence.back()) {
176 has_multiple_children = true;
177 AddNodeToProcess(next_node_in_sequence);
178 break;
179 }
180 }
181 }
182
183 if (has_multiple_children || next_node_in_sequence == nullptr) {
184 // reached end of this transformation sequence.
185 return true;
186 }
187
188 sequence.push_back(next_node_in_sequence->id);
189 // Decrease sequence until it matches expected length.
190 if (sequence.size() > expected_sequence_length) {
191 sequence.pop_front();
192 }
193 }
194 return true;
195 }
196
197 } // namespace gpu
198 } // namespace tflite
199