• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
17 
18 #include <deque>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/lite/delegates/gpu/common/model.h"
25 
26 namespace tflite {
27 namespace gpu {
28 
Apply(const std::string & name,SequenceTransformation * transformation)29 bool ModelTransformer::Apply(const std::string& name,
30                              SequenceTransformation* transformation) {
31   // Seed transformations with starting node. Each node may start a chain of
32   // transformations.
33   for (auto input : graph_->inputs()) {
34     for (auto node : graph_->FindConsumers(input->id)) {
35       AddNodeToProcess(node);
36     }
37   }
38   while (!to_process_.empty()) {
39     auto node = graph_->GetNode(to_process_.front());
40     if (node) {
41       if (!ApplyStartingWithNode(name, transformation, node)) {
42         return false;
43       }
44     }
45     to_process_.pop_front();
46   }
47   processed_.clear();
48   return true;
49 }
50 
Apply(const std::string & name,NodeTransformation * transformation)51 bool ModelTransformer::Apply(const std::string& name,
52                              NodeTransformation* transformation) {
53   // Apply a transformation only to nodes that are present in the graph before
54   // transformation.
55   std::vector<NodeId> nodes;
56   for (auto node : graph_->nodes()) {
57     nodes.push_back(node->id);
58   }
59   for (auto node_id : nodes) {
60     auto node = graph_->GetNode(node_id);
61     if (!node) {
62       continue;
63     }
64     auto result = transformation->ApplyToNode(node, graph_);
65     if (result.status == TransformStatus::INVALID) {
66       return false;
67     }
68     if (reporter_) {
69       if (result.status == TransformStatus::APPLIED) {
70         reporter_->AppliedTransformation(name, std::to_string(node_id),
71                                          result.message);
72       }
73       if (result.status == TransformStatus::DECLINED) {
74         reporter_->DeclinedTransformation(name, std::to_string(node_id),
75                                           result.message);
76       }
77     }
78   }
79   return true;
80 }
81 
ApplyStartingWithNode(const std::string & name,SequenceTransformation * transformation,Node * begin)82 bool ModelTransformer::ApplyStartingWithNode(
83     const std::string& name, SequenceTransformation* transformation,
84     Node* begin) {
85   int expected_sequence_length = transformation->ExpectedSequenceLength();
86 
87   std::deque<NodeId> sequence;
88   std::vector<Node*> nodes;
89   nodes.reserve(transformation->ExpectedSequenceLength());
90   sequence.push_back(begin->id);
91 
92   // Go over nodes with sequence sliding window of size
93   // expected_sequence_length until a node with multiple dependents is found.
94   while (true) {
95     // Apply transformation if possible.
96     if (sequence.size() == expected_sequence_length) {
97       nodes.clear();
98       for (NodeId id : sequence) {
99         // Nodes present in sequence should be present in a graph. If they are
100         // not, then this transformation changes a graph but didn't say it.
101         Node* node = graph_->GetNode(id);
102         if (node == nullptr) {
103           return false;
104         }
105         nodes.push_back(node);
106       }
107 
108       NodeId first_in_sequence = sequence.front();
109       auto preceding_node =
110           graph_->FindProducer(graph_->FindInputs(first_in_sequence)[0]->id);
111       auto result = transformation->ApplyToNodesSequence(nodes, graph_);
112       if (result.status == TransformStatus::INVALID) {
113         // graph is broken now.
114         return false;
115       }
116       if (result.status == TransformStatus::DECLINED) {
117         if (reporter_) {
118           reporter_->DeclinedTransformation(name, absl::StrJoin(sequence, "+"),
119                                             result.message);
120         }
121       } else if (result.status == TransformStatus::APPLIED) {
122         if (reporter_) {
123           reporter_->AppliedTransformation(name, absl::StrJoin(sequence, "+"),
124                                            result.message);
125         }
126         // Also remove first node of a sequence from a set of processed node.
127         // Out of all nodes in a sequence only first one may have been added
128         // to "processed" set because other nodes do not have more than one
129         // dependent. However, if a sequence is changed, then processing needs
130         // to be restarted again.
131         processed_.erase(first_in_sequence);
132         // Transformation was successful. Restart sequence from the node that
133         // precedes current sequence.
134         if (preceding_node) {
135           processed_.erase(preceding_node->id);
136           AddNodeToProcess(preceding_node);
137         } else {
138           // This is the first node in the graph. Re-seed transformation.
139           for (auto input : graph_->inputs()) {
140             for (auto node : graph_->FindConsumers(input->id)) {
141               AddNodeToProcess(node);
142             }
143           }
144         }
145         return true;
146       }
147     }
148 
149     // Try to extend current sequence.
150     Node* next_node_in_sequence = nullptr;
151     bool has_multiple_children = false;
152 
153     // Check that all outputs from last node are consumed by a single node.
154     for (auto output_value : graph_->FindOutputs(sequence.back())) {
155       for (auto dependent : graph_->FindConsumers(output_value->id)) {
156         if (has_multiple_children) {
157           AddNodeToProcess(dependent);
158         } else if (next_node_in_sequence == nullptr) {
159           next_node_in_sequence = dependent;
160         } else if (next_node_in_sequence != dependent) {
161           // There are more than two nodes depend on the output from end node,
162           // therefore here a sequence stops and new will start. Push all such
163           // nodes.
164           has_multiple_children = true;
165           AddNodeToProcess(dependent);
166           AddNodeToProcess(next_node_in_sequence);
167         }
168       }
169     }
170 
171     // Now check that next node has inputs only produced by the last node.
172     if (!has_multiple_children && next_node_in_sequence) {
173       for (auto input : graph_->FindInputs(next_node_in_sequence->id)) {
174         auto producer = graph_->FindProducer(input->id);
175         if (producer == nullptr || producer->id != sequence.back()) {
176           has_multiple_children = true;
177           AddNodeToProcess(next_node_in_sequence);
178           break;
179         }
180       }
181     }
182 
183     if (has_multiple_children || next_node_in_sequence == nullptr) {
184       // reached end of this transformation sequence.
185       return true;
186     }
187 
188     sequence.push_back(next_node_in_sequence->id);
189     // Decrease sequence until it matches expected length.
190     if (sequence.size() > expected_sequence_length) {
191       sequence.pop_front();
192     }
193   }
194   return true;
195 }
196 
197 }  // namespace gpu
198 }  // namespace tflite
199