/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"

#include <memory>
#include <string>
#include <variant>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/any.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/matching.h"

namespace tflite {
namespace gpu {
namespace {

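// Merges a preceding PAD node into the 2D operation that immediately follows
// it (pooling, convolution, or depthwise convolution): the PAD's zero padding
// on the H/W axes is folded into the operation's own padding attributes and
// the PAD node is removed from the graph.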
template <typename Attr>
class MergePaddingWith2DOperation : public SequenceTransformation {
 public:
  explicit MergePaddingWith2DOperation(OperationType operation_type)
      : operations_to_match_(
            {ToString(OperationType::PAD), ToString(operation_type)}) {}

  int ExpectedSequenceLength() const final { return 2; }

  TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
                                       GraphFloat32* graph) final {
    if (!MatchesByOperationType(sequence, operations_to_match_)) {
      return {TransformStatus::SKIPPED, ""};
    }

    Node* pad_node = sequence.front();
    Node* op_node = sequence.back();

    PadAttributes pad_attr =
        absl::any_cast<PadAttributes>(pad_node->operation.attributes);

    if (pad_attr.type != PaddingContentType::ZEROS) {
      return {TransformStatus::DECLINED, "Only Zero padding is supported."};
    }
    if (pad_attr.appended.c != 0 || pad_attr.prepended.c != 0 ||
        pad_attr.appended.b != 0 || pad_attr.prepended.b != 0) {
      return {TransformStatus::DECLINED,
              "Pad has non-zero padding on non HW axis."};
    }

    Attr* node_attr = absl::any_cast<Attr>(&op_node->operation.attributes);
    absl::Status status = RemovePrecedingNode(graph, pad_node, op_node);
    if (!status.ok()) {
      return {TransformStatus::INVALID,
              "Unable to remove Pad node with Operation node: " +
                  std::string(status.message())};
    }

    node_attr->padding.appended.h += pad_attr.appended.h;
    node_attr->padding.appended.w += pad_attr.appended.w;
    node_attr->padding.prepended.h += pad_attr.prepended.h;
    node_attr->padding.prepended.w += pad_attr.prepended.w;
    return {
        TransformStatus::APPLIED,
        absl::StrCat("Added padding: prepended = {h = ", pad_attr.prepended.h,
                     ", w = ", pad_attr.prepended.w, "}, appended = { h = ",
                     pad_attr.appended.h, ", w = ", pad_attr.appended.w, "}")};
  }

 private:
  const std::vector<std::string> operations_to_match_;
};

}  // namespace

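// Factory functions declared in merge_padding_with.h; each instantiates the
// transformation above for a specific 2D operation type.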
std::unique_ptr<SequenceTransformation> NewMergePaddingWithPooling() {
  return absl::make_unique<MergePaddingWith2DOperation<Pooling2DAttributes>>(
      OperationType::POOLING_2D);
}

std::unique_ptr<SequenceTransformation> NewMergePaddingWithConvolution2D() {
  return absl::make_unique<
      MergePaddingWith2DOperation<Convolution2DAttributes>>(
      OperationType::CONVOLUTION_2D);
}

std::unique_ptr<SequenceTransformation>
NewMergePaddingWithDepthwiseConvolution() {
  return absl::make_unique<
      MergePaddingWith2DOperation<DepthwiseConvolution2DAttributes>>(
      OperationType::DEPTHWISE_CONVOLUTION);
}

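// Removes a PAD node that only appends zero channels when its single consumer
// is an ADD without a constant argument; the appended channels are zero, so
// the ADD can consume the unpadded tensor directly.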
class MergePaddingWithAddOperation : public NodeTransformation {
 public:
  TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
    if (node->operation.type != ToString(OperationType::PAD)) {
      return {TransformStatus::SKIPPED, ""};
    }
    auto inputs = graph->FindInputs(node->id);
    if (inputs.size() != 1) {
      return {TransformStatus::SKIPPED, ""};
    }

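    // Only handled when the unpadded channel count is a multiple of 4, so the
    // appended zero channels line up with whole 4-channel slices in the GPU
    // tensor layout.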
    const auto& input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
    if (input_shape.c % 4 != 0) {
      return {TransformStatus::DECLINED,
              "Pad with input where src_channels % 4 != 0"};
    }

    PadAttributes pad_attr =
        absl::any_cast<PadAttributes>(node->operation.attributes);

    if (pad_attr.type != PaddingContentType::ZEROS) {
      return {TransformStatus::DECLINED, "Only Zero padding is supported."};
    }
    if (pad_attr.prepended != BHWC(0, 0, 0, 0) || pad_attr.appended.h != 0 ||
        pad_attr.appended.w != 0 || pad_attr.appended.b != 0) {
      return {TransformStatus::DECLINED,
              "Pad has padding not only in appended channels axis."};
    }

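    // The padded tensor must have exactly one consumer, and it must be an ADD.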
    auto pad_output = graph->FindOutputs(node->id)[0];
    auto consumer_nodes = graph->FindConsumers(pad_output->id);
    if (consumer_nodes.size() != 1) {
      return {TransformStatus::SKIPPED, ""};
    }
    auto add_node = consumer_nodes[0];
    auto consumer_type = OperationTypeFromString(add_node->operation.type);
    if (consumer_type != OperationType::ADD) {
      return {TransformStatus::SKIPPED, ""};
    }

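    // If the ADD also folds in a constant (per-element, per-channel, or
    // scalar), the appended zero channels would combine with the constant's
    // values, so dropping the PAD is not safe.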
    ElementwiseAttributes add_attr =
        absl::any_cast<ElementwiseAttributes>(add_node->operation.attributes);
    const bool is_add_hwc =
        absl::holds_alternative<Tensor<HWC, DataType::FLOAT32>>(add_attr.param);
    const bool is_add_linear =
        absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
            add_attr.param);
    const bool is_add_scalar = absl::holds_alternative<float>(add_attr.param);
    if (is_add_hwc || is_add_linear || is_add_scalar) {
      return {TransformStatus::SKIPPED,
              "Cannot remove padding when ADD has constant argument."};
    }

    absl::Status status = RemovePrecedingNode(graph, node, add_node);
    if (!status.ok()) {
      return {TransformStatus::INVALID,
              "Unable to remove Pad node " + std::string(status.message())};
    }

    return {TransformStatus::APPLIED,
            "Removed padding with zeroes in appended channels dimension"};
  }
};

std::unique_ptr<NodeTransformation> NewMergePaddingWithAdd() {
  return absl::make_unique<MergePaddingWithAddOperation>();
}

}  // namespace gpu
}  // namespace tflite