/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h"
17
18 #include <any>
19 #include <memory>
20 #include <string>
21 #include <variant>
22 #include <vector>
23
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
26 #include "tensorflow/lite/delegates/gpu/common/model.h"
27 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
28 #include "tensorflow/lite/delegates/gpu/common/operations.h"
29 #include "tensorflow/lite/delegates/gpu/common/shape.h"
30 #include "tensorflow/lite/delegates/gpu/common/status.h"
31 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
32
33 namespace tflite {
34 namespace gpu {
35 namespace {
36
FuseBiasWithAddAttributes(const ElementwiseAttributes & add_attr,const int channels,Tensor<Linear,DataType::FLOAT32> * bias)37 void FuseBiasWithAddAttributes(const ElementwiseAttributes& add_attr,
38 const int channels,
39 Tensor<Linear, DataType::FLOAT32>* bias) {
40 auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
41 auto add_scalar = absl::get_if<float>(&add_attr.param);
42 if (bias->data.empty()) {
43 *bias = MakeZeroTensor<Linear, DataType::FLOAT32>(Linear(channels));
44 }
45 for (int d = 0; d < channels; ++d) {
46 bias->data[d] += add ? add->data[d] : *add_scalar;
47 }
48 }
49
50 class MergeConvolutionWithAdd : public SequenceTransformation {
51 public:
ExpectedSequenceLength() const52 int ExpectedSequenceLength() const final { return 2; }
53
ApplyToNodesSequence(const std::vector<Node * > & sequence,GraphFloat32 * graph)54 TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
55 GraphFloat32* graph) final {
56 auto& conv_node = *sequence[0];
57 if (graph->FindInputs(conv_node.id).size() != 1) {
58 return {TransformStatus::DECLINED,
59 "This fusion is only applicable to ops with one runtime input."};
60 }
61 auto& add_node = *sequence[1];
62 if (add_node.operation.type != ToString(OperationType::ADD)) {
63 return {TransformStatus::SKIPPED, ""};
64 }
65 ElementwiseAttributes add_attr =
66 absl::any_cast<ElementwiseAttributes>(add_node.operation.attributes);
67 if (!absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
68 add_attr.param) &&
69 !absl::holds_alternative<float>(add_attr.param)) {
70 return {TransformStatus::DECLINED,
71 "This fuse applicable only for broadcast or scalar addition."};
72 }
73
74 if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
75 Convolution2DAttributes* conv_attr =
76 absl::any_cast<Convolution2DAttributes>(
77 &conv_node.operation.attributes);
78 FuseConvolution2DWithAdd(add_attr, conv_attr);
79 } else if (conv_node.operation.type ==
80 ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
81 ConvolutionTransposedAttributes* conv_attr =
82 absl::any_cast<ConvolutionTransposedAttributes>(
83 &conv_node.operation.attributes);
84 FuseConvolutionTransposedWithAdd(add_attr, conv_attr);
85 } else if (conv_node.operation.type ==
86 ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
87 DepthwiseConvolution2DAttributes* conv_attr =
88 absl::any_cast<DepthwiseConvolution2DAttributes>(
89 &conv_node.operation.attributes);
90 FuseDepthwiseConvolution2DWithAdd(add_attr, conv_attr);
91 } else if (conv_node.operation.type ==
92 ToString(OperationType::FULLY_CONNECTED)) {
93 FullyConnectedAttributes* conv_attr =
94 absl::any_cast<FullyConnectedAttributes>(
95 &conv_node.operation.attributes);
96 FuseFullyConnectedWithAdd(add_attr, conv_attr);
97 } else {
98 return {TransformStatus::SKIPPED, ""};
99 }
100
101 absl::Status status = RemoveFollowingNode(graph, &add_node, &conv_node);
102 if (!status.ok()) {
103 return {TransformStatus::INVALID,
104 "Unable to remove add node after convolution: " +
105 std::string(status.message())};
106 }
107 return {TransformStatus::APPLIED, ""};
108 }
109 };
110
111 } // namespace
112
NewMergeConvolutionWithAdd()113 std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithAdd() {
114 return absl::make_unique<MergeConvolutionWithAdd>();
115 }
116
FuseConvolution2DWithAdd(const ElementwiseAttributes & add_attr,Convolution2DAttributes * attr)117 void FuseConvolution2DWithAdd(const ElementwiseAttributes& add_attr,
118 Convolution2DAttributes* attr) {
119 FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
120 }
121
FuseDepthwiseConvolution2DWithAdd(const ElementwiseAttributes & add_attr,DepthwiseConvolution2DAttributes * attr)122 void FuseDepthwiseConvolution2DWithAdd(const ElementwiseAttributes& add_attr,
123 DepthwiseConvolution2DAttributes* attr) {
124 FuseBiasWithAddAttributes(
125 add_attr, attr->weights.shape.o * attr->weights.shape.i, &attr->bias);
126 }
127
FuseConvolutionTransposedWithAdd(const ElementwiseAttributes & add_attr,ConvolutionTransposedAttributes * attr)128 void FuseConvolutionTransposedWithAdd(const ElementwiseAttributes& add_attr,
129 ConvolutionTransposedAttributes* attr) {
130 FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
131 }
132
FuseFullyConnectedWithAdd(const ElementwiseAttributes & add_attr,FullyConnectedAttributes * attr)133 void FuseFullyConnectedWithAdd(const ElementwiseAttributes& add_attr,
134 FullyConnectedAttributes* attr) {
135 FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
136 }
137
138 } // namespace gpu
139 } // namespace tflite
140