• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.h"
17 
18 #include <any>
19 #include <memory>
20 #include <string>
21 #include <variant>
22 #include <vector>
23 
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
26 #include "tensorflow/lite/delegates/gpu/common/model.h"
27 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
28 #include "tensorflow/lite/delegates/gpu/common/operations.h"
29 #include "tensorflow/lite/delegates/gpu/common/shape.h"
30 #include "tensorflow/lite/delegates/gpu/common/status.h"
31 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
32 
33 namespace tflite {
34 namespace gpu {
35 namespace {
36 
FuseBiasWithAddAttributes(const ElementwiseAttributes & add_attr,const int channels,Tensor<Linear,DataType::FLOAT32> * bias)37 void FuseBiasWithAddAttributes(const ElementwiseAttributes& add_attr,
38                                const int channels,
39                                Tensor<Linear, DataType::FLOAT32>* bias) {
40   auto add = absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&add_attr.param);
41   auto add_scalar = absl::get_if<float>(&add_attr.param);
42   if (bias->data.empty()) {
43     *bias = MakeZeroTensor<Linear, DataType::FLOAT32>(Linear(channels));
44   }
45   for (int d = 0; d < channels; ++d) {
46     bias->data[d] += add ? add->data[d] : *add_scalar;
47   }
48 }
49 
50 class MergeConvolutionWithAdd : public SequenceTransformation {
51  public:
ExpectedSequenceLength() const52   int ExpectedSequenceLength() const final { return 2; }
53 
ApplyToNodesSequence(const std::vector<Node * > & sequence,GraphFloat32 * graph)54   TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
55                                        GraphFloat32* graph) final {
56     auto& conv_node = *sequence[0];
57     if (graph->FindInputs(conv_node.id).size() != 1) {
58       return {TransformStatus::DECLINED,
59               "This fusion is only applicable to ops with one runtime input."};
60     }
61     auto& add_node = *sequence[1];
62     if (add_node.operation.type != ToString(OperationType::ADD)) {
63       return {TransformStatus::SKIPPED, ""};
64     }
65     ElementwiseAttributes add_attr =
66         absl::any_cast<ElementwiseAttributes>(add_node.operation.attributes);
67     if (!absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
68             add_attr.param) &&
69         !absl::holds_alternative<float>(add_attr.param)) {
70       return {TransformStatus::DECLINED,
71               "This fuse applicable only for broadcast or scalar addition."};
72     }
73 
74     if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) {
75       Convolution2DAttributes* conv_attr =
76           absl::any_cast<Convolution2DAttributes>(
77               &conv_node.operation.attributes);
78       FuseConvolution2DWithAdd(add_attr, conv_attr);
79     } else if (conv_node.operation.type ==
80                ToString(OperationType::CONVOLUTION_TRANSPOSED)) {
81       ConvolutionTransposedAttributes* conv_attr =
82           absl::any_cast<ConvolutionTransposedAttributes>(
83               &conv_node.operation.attributes);
84       FuseConvolutionTransposedWithAdd(add_attr, conv_attr);
85     } else if (conv_node.operation.type ==
86                ToString(OperationType::DEPTHWISE_CONVOLUTION)) {
87       DepthwiseConvolution2DAttributes* conv_attr =
88           absl::any_cast<DepthwiseConvolution2DAttributes>(
89               &conv_node.operation.attributes);
90       FuseDepthwiseConvolution2DWithAdd(add_attr, conv_attr);
91     } else if (conv_node.operation.type ==
92                ToString(OperationType::FULLY_CONNECTED)) {
93       FullyConnectedAttributes* conv_attr =
94           absl::any_cast<FullyConnectedAttributes>(
95               &conv_node.operation.attributes);
96       FuseFullyConnectedWithAdd(add_attr, conv_attr);
97     } else {
98       return {TransformStatus::SKIPPED, ""};
99     }
100 
101     absl::Status status = RemoveFollowingNode(graph, &add_node, &conv_node);
102     if (!status.ok()) {
103       return {TransformStatus::INVALID,
104               "Unable to remove add node after convolution: " +
105                   std::string(status.message())};
106     }
107     return {TransformStatus::APPLIED, ""};
108   }
109 };
110 
111 }  // namespace
112 
NewMergeConvolutionWithAdd()113 std::unique_ptr<SequenceTransformation> NewMergeConvolutionWithAdd() {
114   return absl::make_unique<MergeConvolutionWithAdd>();
115 }
116 
FuseConvolution2DWithAdd(const ElementwiseAttributes & add_attr,Convolution2DAttributes * attr)117 void FuseConvolution2DWithAdd(const ElementwiseAttributes& add_attr,
118                               Convolution2DAttributes* attr) {
119   FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
120 }
121 
FuseDepthwiseConvolution2DWithAdd(const ElementwiseAttributes & add_attr,DepthwiseConvolution2DAttributes * attr)122 void FuseDepthwiseConvolution2DWithAdd(const ElementwiseAttributes& add_attr,
123                                        DepthwiseConvolution2DAttributes* attr) {
124   FuseBiasWithAddAttributes(
125       add_attr, attr->weights.shape.o * attr->weights.shape.i, &attr->bias);
126 }
127 
FuseConvolutionTransposedWithAdd(const ElementwiseAttributes & add_attr,ConvolutionTransposedAttributes * attr)128 void FuseConvolutionTransposedWithAdd(const ElementwiseAttributes& add_attr,
129                                       ConvolutionTransposedAttributes* attr) {
130   FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
131 }
132 
FuseFullyConnectedWithAdd(const ElementwiseAttributes & add_attr,FullyConnectedAttributes * attr)133 void FuseFullyConnectedWithAdd(const ElementwiseAttributes& add_attr,
134                                FullyConnectedAttributes* attr) {
135   FuseBiasWithAddAttributes(add_attr, attr->weights.shape.o, &attr->bias);
136 }
137 
138 }  // namespace gpu
139 }  // namespace tflite
140