• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/transformations/remove_noop.h"
17 
18 #include <algorithm>
19 #include <any>
20 #include <functional>
21 #include <iterator>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <variant>
26 #include <vector>
27 
28 #include "absl/memory/memory.h"
29 #include "absl/strings/string_view.h"
30 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
31 #include "tensorflow/lite/delegates/gpu/common/model.h"
32 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
33 #include "tensorflow/lite/delegates/gpu/common/operations.h"
34 #include "tensorflow/lite/delegates/gpu/common/shape.h"
35 #include "tensorflow/lite/delegates/gpu/common/status.h"
36 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
37 
38 namespace tflite {
39 namespace gpu {
40 namespace {
41 
42 using ShouldRemoveOperation = std::function<bool(GraphFloat32* graph, Node*)>;
43 
44 class RemoveOperation : public SequenceTransformation {
45  public:
RemoveOperation(ShouldRemoveOperation remove_predicate)46   explicit RemoveOperation(ShouldRemoveOperation remove_predicate)
47       : remove_predicate_(std::move(remove_predicate)) {}
48 
ExpectedSequenceLength() const49   int ExpectedSequenceLength() const final { return 2; }
50 
ApplyToNodesSequence(const std::vector<Node * > & sequence,GraphFloat32 * graph)51   TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
52                                        GraphFloat32* graph) final {
53     Node* prev_op_node = sequence.front();
54     Node* op_node = sequence.back();
55     if (!remove_predicate_(graph, op_node)) {
56       return {TransformStatus::SKIPPED, ""};
57     }
58     absl::Status status = RemoveFollowingNode(graph, op_node, prev_op_node);
59     if (!status.ok()) {
60       return {TransformStatus::INVALID,
61               "Unable to remove a node: " + std::string(status.message())};
62     }
63     return {TransformStatus::APPLIED, ""};
64   }
65 
66  private:
67   ShouldRemoveOperation remove_predicate_;
68 };
69 
70 }  // namespace
71 
NewRemoveSingleInputConcat()72 std::unique_ptr<SequenceTransformation> NewRemoveSingleInputConcat() {
73   // Using SequenceTransformation implies that CONCAT has a single input.
74   auto type = ToString(OperationType::CONCAT);
75   return absl::make_unique<RemoveOperation>(
76       [type](GraphFloat32* graph, Node* node) {
77         return type == node->operation.type;
78       });
79 }
80 
NewRemoveSingleInputAdd()81 std::unique_ptr<SequenceTransformation> NewRemoveSingleInputAdd() {
82   // Using SequenceTransformation implies that ADD has a single input.
83   auto type = ToString(OperationType::ADD);
84   return absl::make_unique<RemoveOperation>(
85       [type](GraphFloat32* graph, Node* node) {
86         if (node->operation.type != type) {
87           return false;
88         }
89         auto& attr = absl::any_cast<const ElementwiseAttributes&>(
90             node->operation.attributes);
91         return !absl::holds_alternative<Tensor<HWC, DataType::FLOAT32>>(
92                    attr.param) &&
93                !absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
94                    attr.param) &&
95                !absl::holds_alternative<float>(attr.param);
96       });
97 }
98 
NewRemoveDegenerateUpsampling()99 std::unique_ptr<SequenceTransformation> NewRemoveDegenerateUpsampling() {
100   auto type = ToString(OperationType::RESIZE);
101   return absl::make_unique<RemoveOperation>(
102       [type](GraphFloat32* graph, Node* node) {
103         if (node->operation.type != type) {
104           return false;
105         }
106         auto inputs = graph->FindInputs(node->id);
107         auto outputs = graph->FindOutputs(node->id);
108         return inputs.size() == 1 && outputs.size() == 1 &&
109                inputs[0]->tensor.shape == outputs[0]->tensor.shape;
110       });
111 }
112 
113 class RemoveIdentityReshape : public NodeTransformation {
114  public:
ApplyToNode(Node * node,GraphFloat32 * graph)115   TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
116     if (node->operation.type != ToString(OperationType::RESHAPE)) {
117       return {TransformStatus::SKIPPED, ""};
118     }
119     auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
120     const auto& reshape_attr =
121         absl::any_cast<const ReshapeAttributes&>(node->operation.attributes);
122     if (input_shape != reshape_attr.new_shape) {
123       return {TransformStatus::SKIPPED, ""};
124     }
125     auto output = graph->FindOutputs(node->id)[0];
126     const auto& graph_outputs = graph->outputs();
127     if (std::find(graph_outputs.begin(), graph_outputs.end(), output) !=
128         graph_outputs.end()) {
129       return {TransformStatus::SKIPPED,
130               "Can not apply transformation when node output is graph output"};
131     }
132     absl::Status status = RemoveSimpleNodeKeepInput(graph, node);
133     if (!status.ok()) {
134       return {TransformStatus::INVALID,
135               "Unable to remove a node: " + std::string(status.message())};
136     }
137     return {TransformStatus::APPLIED,
138             "Removed reshape with input_shape == output_shape."};
139   }
140 };
141 
NewRemoveIdentityReshape()142 std::unique_ptr<NodeTransformation> NewRemoveIdentityReshape() {
143   return absl::make_unique<RemoveIdentityReshape>();
144 }
145 
146 }  // namespace gpu
147 }  // namespace tflite
148