1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h"
17
18 #include <algorithm>
19 #include <any>
20 #include <functional>
21 #include <iterator>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <variant>
26 #include <vector>
27
28 #include "absl/memory/memory.h"
29 #include "absl/strings/string_view.h"
30 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
31 #include "tensorflow/lite/delegates/gpu/common/model.h"
32 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
33 #include "tensorflow/lite/delegates/gpu/common/operations.h"
34 #include "tensorflow/lite/delegates/gpu/common/shape.h"
35 #include "tensorflow/lite/delegates/gpu/common/status.h"
36 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
37
38 namespace tflite {
39 namespace gpu {
40 namespace {
41
42 using ShouldRemoveOperation = std::function<bool(GraphFloat32* graph, Node*)>;
43
44 class RemoveOperation : public SequenceTransformation {
45 public:
RemoveOperation(ShouldRemoveOperation remove_predicate)46 explicit RemoveOperation(ShouldRemoveOperation remove_predicate)
47 : remove_predicate_(std::move(remove_predicate)) {}
48
ExpectedSequenceLength() const49 int ExpectedSequenceLength() const final { return 2; }
50
ApplyToNodesSequence(const std::vector<Node * > & sequence,GraphFloat32 * graph)51 TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
52 GraphFloat32* graph) final {
53 Node* prev_op_node = sequence.front();
54 Node* op_node = sequence.back();
55 if (!remove_predicate_(graph, op_node)) {
56 return {TransformStatus::SKIPPED, ""};
57 }
58 absl::Status status = RemoveFollowingNode(graph, op_node, prev_op_node);
59 if (!status.ok()) {
60 return {TransformStatus::INVALID,
61 "Unable to remove a node: " + std::string(status.message())};
62 }
63 return {TransformStatus::APPLIED, ""};
64 }
65
66 private:
67 ShouldRemoveOperation remove_predicate_;
68 };
69
70 } // namespace
71
NewRemoveSingleInputConcat()72 std::unique_ptr<SequenceTransformation> NewRemoveSingleInputConcat() {
73 // Using SequenceTransformation implies that CONCAT has a single input.
74 auto type = ToString(OperationType::CONCAT);
75 return absl::make_unique<RemoveOperation>(
76 [type](GraphFloat32* graph, Node* node) {
77 return type == node->operation.type;
78 });
79 }
80
NewRemoveSingleInputAdd()81 std::unique_ptr<SequenceTransformation> NewRemoveSingleInputAdd() {
82 // Using SequenceTransformation implies that ADD has a single input.
83 auto type = ToString(OperationType::ADD);
84 return absl::make_unique<RemoveOperation>(
85 [type](GraphFloat32* graph, Node* node) {
86 if (node->operation.type != type) {
87 return false;
88 }
89 auto& attr = absl::any_cast<const ElementwiseAttributes&>(
90 node->operation.attributes);
91 return !absl::holds_alternative<Tensor<HWC, DataType::FLOAT32>>(
92 attr.param) &&
93 !absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
94 attr.param) &&
95 !absl::holds_alternative<float>(attr.param);
96 });
97 }
98
NewRemoveDegenerateUpsampling()99 std::unique_ptr<SequenceTransformation> NewRemoveDegenerateUpsampling() {
100 auto type = ToString(OperationType::RESIZE);
101 return absl::make_unique<RemoveOperation>(
102 [type](GraphFloat32* graph, Node* node) {
103 if (node->operation.type != type) {
104 return false;
105 }
106 auto inputs = graph->FindInputs(node->id);
107 auto outputs = graph->FindOutputs(node->id);
108 return inputs.size() == 1 && outputs.size() == 1 &&
109 inputs[0]->tensor.shape == outputs[0]->tensor.shape;
110 });
111 }
112
113 class RemoveIdentityReshape : public NodeTransformation {
114 public:
ApplyToNode(Node * node,GraphFloat32 * graph)115 TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
116 if (node->operation.type != ToString(OperationType::RESHAPE)) {
117 return {TransformStatus::SKIPPED, ""};
118 }
119 auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
120 const auto& reshape_attr =
121 absl::any_cast<const ReshapeAttributes&>(node->operation.attributes);
122 if (input_shape != reshape_attr.new_shape) {
123 return {TransformStatus::SKIPPED, ""};
124 }
125 auto output = graph->FindOutputs(node->id)[0];
126 const auto& graph_outputs = graph->outputs();
127 if (std::find(graph_outputs.begin(), graph_outputs.end(), output) !=
128 graph_outputs.end()) {
129 return {TransformStatus::SKIPPED,
130 "Can not apply transformation when node output is graph output"};
131 }
132 absl::Status status = RemoveSimpleNodeKeepInput(graph, node);
133 if (!status.ok()) {
134 return {TransformStatus::INVALID,
135 "Unable to remove a node: " + std::string(status.message())};
136 }
137 return {TransformStatus::APPLIED,
138 "Removed reshape with input_shape == output_shape."};
139 }
140 };
141
NewRemoveIdentityReshape()142 std::unique_ptr<NodeTransformation> NewRemoveIdentityReshape() {
143 return absl::make_unique<RemoveIdentityReshape>();
144 }
145
146 class RemoveIdentityStridedSlice : public NodeTransformation {
147 public:
ApplyToNode(Node * node,GraphFloat32 * graph)148 TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
149 if (node->operation.type != ToString(OperationType::SLICE)) {
150 return {TransformStatus::SKIPPED, ""};
151 }
152 auto input = graph->FindInputs(node->id)[0];
153 auto output = graph->FindOutputs(node->id)[0];
154 const auto& slice_attr =
155 absl::any_cast<const SliceAttributes&>(node->operation.attributes);
156 if (input->tensor.shape != output->tensor.shape) {
157 return {TransformStatus::SKIPPED, ""};
158 }
159 if (slice_attr.starts != BHWC(0, 0, 0, 0)) {
160 return {TransformStatus::SKIPPED, ""};
161 }
162 if (slice_attr.strides != BHWC(1, 1, 1, 1)) {
163 return {TransformStatus::SKIPPED, ""};
164 }
165 if (slice_attr.ends != output->tensor.shape) {
166 return {TransformStatus::SKIPPED, ""};
167 }
168 const auto& graph_outputs = graph->outputs();
169 const auto& graph_inputs = graph->inputs();
170 const bool input_is_graph_input =
171 std::find(graph_inputs.begin(), graph_inputs.end(), input) !=
172 graph_inputs.end();
173 const bool output_is_graph_output =
174 std::find(graph_outputs.begin(), graph_outputs.end(), output) !=
175 graph_outputs.end();
176 if (input_is_graph_input && output_is_graph_output) {
177 return {TransformStatus::SKIPPED,
178 "Can not apply transformation when node input is graph input and "
179 "node output is graph output"};
180 }
181 if (output_is_graph_output) {
182 if (graph->FindConsumers(input->id).size() != 1) {
183 return {TransformStatus::SKIPPED,
184 "Can not apply transformation when node output is graph output "
185 "and input consumed by other nodes."};
186 }
187 absl::Status status = RemoveSimpleNodeKeepOutput(graph, node);
188 if (!status.ok()) {
189 return {TransformStatus::INVALID,
190 "Unable to remove a node: " + std::string(status.message())};
191 }
192 return {TransformStatus::APPLIED, "Removed identity strided slice."};
193 }
194 absl::Status status = RemoveSimpleNodeKeepInput(graph, node);
195 if (!status.ok()) {
196 return {TransformStatus::INVALID,
197 "Unable to remove a node: " + std::string(status.message())};
198 }
199 return {TransformStatus::APPLIED, "Removed identity strided slice."};
200 }
201 };
202
NewRemoveIdentityStridedSlice()203 std::unique_ptr<NodeTransformation> NewRemoveIdentityStridedSlice() {
204 return absl::make_unique<RemoveIdentityStridedSlice>();
205 }
206
207 } // namespace gpu
208 } // namespace tflite
209