1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h"
17
18 #include <any>
19 #include <memory>
20 #include <string>
21 #include <variant>
22 #include <vector>
23
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
26 #include "tensorflow/lite/delegates/gpu/common/model.h"
27 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
28 #include "tensorflow/lite/delegates/gpu/common/operations.h"
29 #include "tensorflow/lite/delegates/gpu/common/shape.h"
30 #include "tensorflow/lite/delegates/gpu/common/status.h"
31 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
32
33 namespace tflite {
34 namespace gpu {
35 namespace {
36
FuseBiasWithAddAttributes(const ElementwiseAttributes & add_attr,const int channels,Tensor<Linear,DataType::FLOAT32> * bias)37 void FuseBiasWithAddAttributes(const ElementwiseAttributes& add_attr,
38 const int channels,
39 Tensor<Linear, DataType::FLOAT32>* bias) {
40 auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
41 auto add_scalar = absl::get_if<float>(&add_attr.param);
42 if (bias->data.empty()) {
43 *bias = MakeZeroTensor<Linear, DataType::FLOAT32>(Linear(channels));
44 }
45 for (int d = 0; d < channels; ++d) {
46 bias->data[d] += add ? add->data[d] : *add_scalar;
47 }
48 }
49
50 class MergeConvolutionWithAdd : public SequenceTransformation {
51 public:
ExpectedSequenceLength() const52 int ExpectedSequenceLength() const final { return 2; }
53
ApplyToNodesSequence(const std::vector<Node * > & sequence,GraphFloat32 * graph)54 TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
55 GraphFloat32* graph) final {
56 auto& conv_node = *sequence[0];
57 if (graph->FindInputs(conv_node.id).size() != 1) {
58 return {TransformStatus::DECLINED,
59 "This fusion is only applicable to ops with one runtime input."};
60 }
61 auto& add_node = *sequence[1];
62 if (add_node.operation.type != ToString(OperationType::ADD)) {
63 return {TransformStatus::SKIPPED, ""};
64 }
65 ElementwiseAttributes add_attr =
66 absl::any_cast<ElementwiseAttributes>(add_node.operation.attributes);
67 if (!absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
68 add_attr.param) &&
69 !absl::holds_alternative<float>(add_attr.param)) {
70 return {TransformStatus::DECLINED,
71 "This fuse applicable only for broadcast or scalar addition."};
72 }
73
74 if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
75 Convolution2DAttributes* conv_attr =
76 absl::any_cast<Convolution2DAttributes>(
77 &conv_node.operation.attributes);
78 FuseConvolution2DWithAdd(add_attr, conv_attr);
79 } else if (conv_node.operation.type ==
80 ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
81 ConvolutionTransposedAttributes* conv_attr =
82 absl::any_cast<ConvolutionTransposedAttributes>(
83 &conv_node.operation.attributes);
84 FuseConvolutionTransposedWithAdd(add_attr, conv_attr);
85 } else if (conv_node.operation.type ==
86 ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
87 DepthwiseConvolution2DAttributes* conv_attr =
88 absl::any_cast<DepthwiseConvolution2DAttributes>(
89 &conv_node.operation.attributes);
90 FuseDepthwiseConvolution2DWithAdd(add_attr, conv_attr);
91 } else if (conv_node.operation.type ==
92 ToString(OperationType::FULLY_CONNECTED)) {
93 FullyConnectedAttributes* conv_attr =
94 absl::any_cast<FullyConnectedAttributes>(
95 &conv_node.operation.attributes);
96 FuseFullyConnectedWithAdd(add_attr, conv_attr);
97 } else {
98 return {TransformStatus::SKIPPED, ""};
99 }
100
101 absl::Status status = RemoveFollowingNode(graph, &add_node, &conv_node);
102 if (!status.ok()) {
103 return {TransformStatus::INVALID,
104 "Unable to remove add node after convolution: " +
105 std::string(status.message())};
106 }
107 return {TransformStatus::APPLIED, ""};
108 }
109 };
110
FuseAddWithConvolution2D(const ElementwiseAttributes & add_attr,Convolution2DAttributes * attr)111 void FuseAddWithConvolution2D(const ElementwiseAttributes& add_attr,
112 Convolution2DAttributes* attr) {
113 auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
114 auto add_scalar = absl::get_if<float>(&add_attr.param);
115 if (attr->bias.data.empty()) {
116 attr->bias = MakeZeroTensor<Linear, DataType::FLOAT32>(
117 Linear(attr->weights.shape.o));
118 }
119 for (int d = 0; d < attr->weights.shape.o; ++d) {
120 float sum = 0.0f;
121 for (int s = 0; s < attr->weights.shape.i; ++s) {
122 const float add_value = add ? add->data[s] : *add_scalar;
123 for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
124 for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
125 const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
126 sum += add_value * attr->weights.data[index];
127 }
128 }
129 }
130 attr->bias.data[d] += sum;
131 }
132 }
133
134 class MergeAddWithConvolution : public SequenceTransformation {
135 public:
ExpectedSequenceLength() const136 int ExpectedSequenceLength() const final { return 2; }
137
ApplyToNodesSequence(const std::vector<Node * > & sequence,GraphFloat32 * graph)138 TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
139 GraphFloat32* graph) final {
140 auto& conv_node = *sequence[1];
141 if (graph->FindInputs(conv_node.id).size() != 1) {
142 return {TransformStatus::DECLINED,
143 "This fusion is only applicable to ops with one runtime input."};
144 }
145 auto& add_node = *sequence[0];
146 if (add_node.operation.type != ToString(OperationType::ADD)) {
147 return {TransformStatus::SKIPPED, ""};
148 }
149 ElementwiseAttributes add_attr =
150 absl::any_cast<ElementwiseAttributes>(add_node.operation.attributes);
151 if (!absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
152 add_attr.param) &&
153 !absl::holds_alternative<float>(add_attr.param)) {
154 return {TransformStatus::DECLINED,
155 "This fuse applicable only for broadcast or scalar addition."};
156 }
157
158 if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
159 Convolution2DAttributes* conv_attr =
160 absl::any_cast<Convolution2DAttributes>(
161 &conv_node.operation.attributes);
162 if (conv_attr->groups != 1) {
163 return {TransformStatus::DECLINED,
164 "This fuse not applicable for grouped convolution."};
165 }
166 if (conv_attr->padding.appended.w != 0 ||
167 conv_attr->padding.appended.h != 0 ||
168 conv_attr->padding.prepended.w != 0 ||
169 conv_attr->padding.prepended.h != 0) {
170 return {TransformStatus::DECLINED,
171 "This fuse applicable only for convolution that do not read "
172 "out of bound elements."};
173 }
174 FuseAddWithConvolution2D(add_attr, conv_attr);
175 } else {
176 return {TransformStatus::SKIPPED, ""};
177 }
178
179 absl::Status status = RemovePrecedingNode(graph, &add_node, &conv_node);
180 if (!status.ok()) {
181 return {TransformStatus::INVALID,
182 "Unable to remove mul node after convolution: " +
183 std::string(status.message())};
184 }
185 return {TransformStatus::APPLIED, ""};
186 }
187 };
188
189 } // namespace
190
NewMergeConvolutionWithAdd()191 std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithAdd() {
192 return absl::make_unique<MergeConvolutionWithAdd>();
193 }
194
NewMergeAddWithConvolution()195 std::unique_ptr<SequenceTransformation> NewMergeAddWithConvolution() {
196 return absl::make_unique<MergeAddWithConvolution>();
197 }
198
// Folds a following ADD into a 2D convolution's bias; one bias value per
// output channel (weights.shape.o).
void FuseConvolution2DWithAdd(const ElementwiseAttributes& add_attr,
                              Convolution2DAttributes* attr) {
  FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
}
203
// Folds a following ADD into a depthwise convolution's bias. Depthwise output
// has o * i channels (channel multiplier times input channels), hence the
// product as the bias length.
void FuseDepthwiseConvolution2DWithAdd(const ElementwiseAttributes& add_attr,
                                       DepthwiseConvolution2DAttributes* attr) {
  FuseBiasWithAddAttributes(
      add_attr, attr->weights.shape.o * attr->weights.shape.i, &attr->bias);
}
209
// Folds a following ADD into a transposed convolution's bias; one bias value
// per output channel (weights.shape.o).
void FuseConvolutionTransposedWithAdd(const ElementwiseAttributes& add_attr,
                                      ConvolutionTransposedAttributes* attr) {
  FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
}
214
// Folds a following ADD into a fully-connected layer's bias; one bias value
// per output unit (weights.shape.o).
void FuseFullyConnectedWithAdd(const ElementwiseAttributes& add_attr,
                               FullyConnectedAttributes* attr) {
  FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
}
219
220 } // namespace gpu
221 } // namespace tflite
222