• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h"
17 
18 #include <algorithm>
19 #include <any>
20 #include <functional>
21 #include <iterator>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <variant>
26 #include <vector>
27 
28 #include "absl/memory/memory.h"
29 #include "absl/strings/string_view.h"
30 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
31 #include "tensorflow/lite/delegates/gpu/common/model.h"
32 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
33 #include "tensorflow/lite/delegates/gpu/common/operations.h"
34 #include "tensorflow/lite/delegates/gpu/common/shape.h"
35 #include "tensorflow/lite/delegates/gpu/common/status.h"
36 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
37 
38 namespace tflite {
39 namespace gpu {
40 namespace {
41 
42 using ShouldRemoveOperation = std::function<bool(GraphFloat32* graph, Node*)>;
43 
44 class RemoveOperation : public SequenceTransformation {
45  public:
RemoveOperation(ShouldRemoveOperation remove_predicate)46   explicit RemoveOperation(ShouldRemoveOperation remove_predicate)
47       : remove_predicate_(std::move(remove_predicate)) {}
48 
ExpectedSequenceLength() const49   int ExpectedSequenceLength() const final { return 2; }
50 
ApplyToNodesSequence(const std::vector<Node * > & sequence,GraphFloat32 * graph)51   TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
52                                        GraphFloat32* graph) final {
53     Node* prev_op_node = sequence.front();
54     Node* op_node = sequence.back();
55     if (!remove_predicate_(graph, op_node)) {
56       return {TransformStatus::SKIPPED, ""};
57     }
58     absl::Status status = RemoveFollowingNode(graph, op_node, prev_op_node);
59     if (!status.ok()) {
60       return {TransformStatus::INVALID,
61               "Unable to remove a node: " + std::string(status.message())};
62     }
63     return {TransformStatus::APPLIED, ""};
64   }
65 
66  private:
67   ShouldRemoveOperation remove_predicate_;
68 };
69 
70 }  // namespace
71 
NewRemoveSingleInputConcat()72 std::unique_ptr<SequenceTransformation> NewRemoveSingleInputConcat() {
73   // Using SequenceTransformation implies that CONCAT has a single input.
74   auto type = ToString(OperationType::CONCAT);
75   return absl::make_unique<RemoveOperation>(
76       [type](GraphFloat32* graph, Node* node) {
77         return type == node->operation.type;
78       });
79 }
80 
NewRemoveSingleInputAdd()81 std::unique_ptr<SequenceTransformation> NewRemoveSingleInputAdd() {
82   // Using SequenceTransformation implies that ADD has a single input.
83   auto type = ToString(OperationType::ADD);
84   return absl::make_unique<RemoveOperation>(
85       [type](GraphFloat32* graph, Node* node) {
86         if (node->operation.type != type) {
87           return false;
88         }
89         auto& attr = absl::any_cast<const ElementwiseAttributes&>(
90             node->operation.attributes);
91         return !absl::holds_alternative<Tensor<HWC, DataType::FLOAT32>>(
92                    attr.param) &&
93                !absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
94                    attr.param) &&
95                !absl::holds_alternative<float>(attr.param);
96       });
97 }
98 
NewRemoveDegenerateUpsampling()99 std::unique_ptr<SequenceTransformation> NewRemoveDegenerateUpsampling() {
100   auto type = ToString(OperationType::RESIZE);
101   return absl::make_unique<RemoveOperation>(
102       [type](GraphFloat32* graph, Node* node) {
103         if (node->operation.type != type) {
104           return false;
105         }
106         auto inputs = graph->FindInputs(node->id);
107         auto outputs = graph->FindOutputs(node->id);
108         return inputs.size() == 1 && outputs.size() == 1 &&
109                inputs[0]->tensor.shape == outputs[0]->tensor.shape;
110       });
111 }
112 
113 class RemoveIdentityReshape : public NodeTransformation {
114  public:
ApplyToNode(Node * node,GraphFloat32 * graph)115   TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
116     if (node->operation.type != ToString(OperationType::RESHAPE)) {
117       return {TransformStatus::SKIPPED, ""};
118     }
119     auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
120     const auto& reshape_attr =
121         absl::any_cast<const ReshapeAttributes&>(node->operation.attributes);
122     if (input_shape != reshape_attr.new_shape) {
123       return {TransformStatus::SKIPPED, ""};
124     }
125     auto output = graph->FindOutputs(node->id)[0];
126     const auto& graph_outputs = graph->outputs();
127     if (std::find(graph_outputs.begin(), graph_outputs.end(), output) !=
128         graph_outputs.end()) {
129       return {TransformStatus::SKIPPED,
130               "Can not apply transformation when node output is graph output"};
131     }
132     absl::Status status = RemoveSimpleNodeKeepInput(graph, node);
133     if (!status.ok()) {
134       return {TransformStatus::INVALID,
135               "Unable to remove a node: " + std::string(status.message())};
136     }
137     return {TransformStatus::APPLIED,
138             "Removed reshape with input_shape == output_shape."};
139   }
140 };
141 
NewRemoveIdentityReshape()142 std::unique_ptr<NodeTransformation> NewRemoveIdentityReshape() {
143   return absl::make_unique<RemoveIdentityReshape>();
144 }
145 
146 class RemoveIdentityStridedSlice : public NodeTransformation {
147  public:
ApplyToNode(Node * node,GraphFloat32 * graph)148   TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
149     if (node->operation.type != ToString(OperationType::SLICE)) {
150       return {TransformStatus::SKIPPED, ""};
151     }
152     auto input = graph->FindInputs(node->id)[0];
153     auto output = graph->FindOutputs(node->id)[0];
154     const auto& slice_attr =
155         absl::any_cast<const SliceAttributes&>(node->operation.attributes);
156     if (input->tensor.shape != output->tensor.shape) {
157       return {TransformStatus::SKIPPED, ""};
158     }
159     if (slice_attr.starts != BHWC(0, 0, 0, 0)) {
160       return {TransformStatus::SKIPPED, ""};
161     }
162     if (slice_attr.strides != BHWC(1, 1, 1, 1)) {
163       return {TransformStatus::SKIPPED, ""};
164     }
165     if (slice_attr.ends != output->tensor.shape) {
166       return {TransformStatus::SKIPPED, ""};
167     }
168     const auto& graph_outputs = graph->outputs();
169     const auto& graph_inputs = graph->inputs();
170     const bool input_is_graph_input =
171         std::find(graph_inputs.begin(), graph_inputs.end(), input) !=
172         graph_inputs.end();
173     const bool output_is_graph_output =
174         std::find(graph_outputs.begin(), graph_outputs.end(), output) !=
175         graph_outputs.end();
176     if (input_is_graph_input && output_is_graph_output) {
177       return {TransformStatus::SKIPPED,
178               "Can not apply transformation when node input is graph input and "
179               "node output is graph output"};
180     }
181     if (output_is_graph_output) {
182       if (graph->FindConsumers(input->id).size() != 1) {
183         return {TransformStatus::SKIPPED,
184                 "Can not apply transformation when node output is graph output "
185                 "and input consumed by other nodes."};
186       }
187       absl::Status status = RemoveSimpleNodeKeepOutput(graph, node);
188       if (!status.ok()) {
189         return {TransformStatus::INVALID,
190                 "Unable to remove a node: " + std::string(status.message())};
191       }
192       return {TransformStatus::APPLIED, "Removed identity strided slice."};
193     }
194     absl::Status status = RemoveSimpleNodeKeepInput(graph, node);
195     if (!status.ok()) {
196       return {TransformStatus::INVALID,
197               "Unable to remove a node: " + std::string(status.message())};
198     }
199     return {TransformStatus::APPLIED, "Removed identity strided slice."};
200   }
201 };
202 
NewRemoveIdentityStridedSlice()203 std::unique_ptr<NodeTransformation> NewRemoveIdentityStridedSlice() {
204   return absl::make_unique<RemoveIdentityStridedSlice>();
205 }
206 
207 }  // namespace gpu
208 }  // namespace tflite
209