/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h"

#include <any>
#include <memory>
#include <string>
#include <variant>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"

namespace tflite {
namespace gpu {
namespace {

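// Folds the constant operand of an elementwise ADD into a convolution bias:
// bias[d] += add[d], or the scalar addend for every channel. Creates a zero
// bias of `channels` elements first if the op had no bias.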
void FuseBiasWithAddAttributes(const ElementwiseAttributes& add_attr,
                               const int channels,
                               Tensor<Linear, DataType::FLOAT32>* bias) {
  auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
  auto add_scalar = absl::get_if<float>(&add_attr.param);
  if (bias->data.empty()) {
    *bias = MakeZeroTensor<Linear, DataType::FLOAT32>(Linear(channels));
  }
  for (int d = 0; d < channels; ++d) {
    bias->data[d] += add ? add->data[d] : *add_scalar;
  }
}

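// Fuses a convolution followed by an ADD of a constant (per-channel Linear
// tensor or scalar) by folding the addend into the convolution bias, then
// removes the now-redundant ADD node from the graph.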
class MergeConvolutionWithAdd : public SequenceTransformation {
 public:
  int ExpectedSequenceLength() const final { return 2; }

  TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
                                       GraphFloat32* graph) final {
    auto& conv_node = *sequence[0];
    if (graph->FindInputs(conv_node.id).size() != 1) {
      return {TransformStatus::DECLINED,
              "This fusion is only applicable to ops with one runtime input."};
    }
    auto& add_node = *sequence[1];
    if (add_node.operation.type != ToString(OperationType::ADD)) {
      return {TransformStatus::SKIPPED, ""};
    }
    ElementwiseAttributes add_attr =
        absl::any_cast<ElementwiseAttributes>(add_node.operation.attributes);
    if (!absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
            add_attr.param) &&
        !absl::holds_alternative<float>(add_attr.param)) {
      return {TransformStatus::DECLINED,
              "This fusion is applicable only to broadcast or scalar "
              "addition."};
    }

    if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
      Convolution2DAttributes* conv_attr =
          absl::any_cast<Convolution2DAttributes>(
              &conv_node.operation.attributes);
      FuseConvolution2DWithAdd(add_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
      ConvolutionTransposedAttributes* conv_attr =
          absl::any_cast<ConvolutionTransposedAttributes>(
              &conv_node.operation.attributes);
      FuseConvolutionTransposedWithAdd(add_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
      DepthwiseConvolution2DAttributes* conv_attr =
          absl::any_cast<DepthwiseConvolution2DAttributes>(
              &conv_node.operation.attributes);
      FuseDepthwiseConvolution2DWithAdd(add_attr, conv_attr);
    } else if (conv_node.operation.type ==
               ToString(OperationType::FULLY_CONNECTED)) {
      FullyConnectedAttributes* conv_attr =
          absl::any_cast<FullyConnectedAttributes>(
              &conv_node.operation.attributes);
      FuseFullyConnectedWithAdd(add_attr, conv_attr);
    } else {
      return {TransformStatus::SKIPPED, ""};
    }

    absl::Status status = RemoveFollowingNode(graph, &add_node, &conv_node);
    if (!status.ok()) {
      return {TransformStatus::INVALID,
              "Unable to remove add node after convolution: " +
                  std::string(status.message())};
    }
    return {TransformStatus::APPLIED, ""};
  }
};

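// Folds an ADD that precedes a 2D convolution into the convolution bias.
// For a constant per-channel addend a, conv(x + a) == conv(x) + conv(a), and
// conv(a) contributes to each output channel d:
//   sum_{s, k_y, k_x} a[s] * weights[d, k_y, k_x, s].
// This identity only holds when the convolution never samples padded
// elements, which MergeAddWithConvolution verifies before calling this.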
void FuseAddWithConvolution2D(const ElementwiseAttributes& add_attr,
                              Convolution2DAttributes* attr) {
  auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
  auto add_scalar = absl::get_if<float>(&add_attr.param);
  if (attr->bias.data.empty()) {
    attr->bias = MakeZeroTensor<Linear, DataType::FLOAT32>(
        Linear(attr->weights.shape.o));
  }
  for (int d = 0; d < attr->weights.shape.o; ++d) {
    float sum = 0.0f;
    for (int s = 0; s < attr->weights.shape.i; ++s) {
      const float add_value = add ? add->data[s] : *add_scalar;
      for (int k_y = 0; k_y < attr->weights.shape.h; ++k_y) {
        for (int k_x = 0; k_x < attr->weights.shape.w; ++k_x) {
          const int index = attr->weights.shape.LinearIndex({{d, k_y, k_x, s}});
          sum += add_value * attr->weights.data[index];
        }
      }
    }
    attr->bias.data[d] += sum;
  }
}

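// Fuses an ADD of a constant followed by a 2D convolution by folding the
// addend into the convolution bias (see FuseAddWithConvolution2D above), then
// removes the ADD node. Declines grouped convolutions and convolutions with
// padding, where the folding would be incorrect.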
class MergeAddWithConvolution : public SequenceTransformation {
 public:
  int ExpectedSequenceLength() const final { return 2; }

  TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
                                       GraphFloat32* graph) final {
    auto& conv_node = *sequence[1];
    if (graph->FindInputs(conv_node.id).size() != 1) {
      return {TransformStatus::DECLINED,
              "This fusion is only applicable to ops with one runtime input."};
    }
    auto& add_node = *sequence[0];
    if (add_node.operation.type != ToString(OperationType::ADD)) {
      return {TransformStatus::SKIPPED, ""};
    }
    ElementwiseAttributes add_attr =
        absl::any_cast<ElementwiseAttributes>(add_node.operation.attributes);
    if (!absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
            add_attr.param) &&
        !absl::holds_alternative<float>(add_attr.param)) {
      return {TransformStatus::DECLINED,
              "This fusion is applicable only to broadcast or scalar "
              "addition."};
    }

    if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
      Convolution2DAttributes* conv_attr =
          absl::any_cast<Convolution2DAttributes>(
              &conv_node.operation.attributes);
      if (conv_attr->groups != 1) {
        return {TransformStatus::DECLINED,
                "This fusion is not applicable to grouped convolution."};
      }
      if (conv_attr->padding.appended.w != 0 ||
          conv_attr->padding.appended.h != 0 ||
          conv_attr->padding.prepended.w != 0 ||
          conv_attr->padding.prepended.h != 0) {
        return {TransformStatus::DECLINED,
                "This fusion is applicable only to convolutions that do not "
                "read out-of-bounds elements."};
      }
      FuseAddWithConvolution2D(add_attr, conv_attr);
    } else {
      return {TransformStatus::SKIPPED, ""};
    }

    absl::Status status = RemovePrecedingNode(graph, &add_node, &conv_node);
    if (!status.ok()) {
      return {TransformStatus::INVALID,
              "Unable to remove add node before convolution: " +
                  std::string(status.message())};
    }
    return {TransformStatus::APPLIED, ""};
  }
};

}  // namespace

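// Typical usage (a sketch, not the canonical registration; the actual wiring
// lives in model_transformations.cc and the ModelTransformer API may differ):
//   auto fuse = NewMergeConvolutionWithAdd();
//   ModelTransformer transformer(&graph);
//   transformer.Apply("merge_convolution_with_add", fuse.get());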
std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithAdd() {
  return absl::make_unique<MergeConvolutionWithAdd>();
}

std::unique_ptr<SequenceTransformation> NewMergeAddWithConvolution() {
  return absl::make_unique<MergeAddWithConvolution>();
}

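// Public entry points. Each folds the ADD constant into the bias of the
// corresponding op type; the bias length is the op's output channel count.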
void FuseConvolution2DWithAdd(const ElementwiseAttributes& add_attr,
                              Convolution2DAttributes* attr) {
  FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
}

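// For depthwise convolution the output channel count is
// channel_multiplier * input_channels, i.e. weights.shape.o * weights.shape.i.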
void FuseDepthwiseConvolution2DWithAdd(const ElementwiseAttributes& add_attr,
                                       DepthwiseConvolution2DAttributes* attr) {
  FuseBiasWithAddAttributes(
      add_attr, attr->weights.shape.o * attr->weights.shape.i, &attr->bias);
}

void FuseConvolutionTransposedWithAdd(const ElementwiseAttributes& add_attr,
                                      ConvolutionTransposedAttributes* attr) {
  FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
}

void FuseFullyConnectedWithAdd(const ElementwiseAttributes& add_attr,
                               FullyConnectedAttributes* attr) {
  FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
}

}  // namespace gpu
}  // namespace tflite