• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/transformations/merge_padding_with.h"
17 
18 #include <memory>
19 #include <string>
20 #include <variant>
21 #include <vector>
22 
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/types/any.h"
27 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
28 #include "tensorflow/lite/delegates/gpu/common/model.h"
29 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
30 #include "tensorflow/lite/delegates/gpu/common/operations.h"
31 #include "tensorflow/lite/delegates/gpu/common/shape.h"
32 #include "tensorflow/lite/delegates/gpu/common/status.h"
33 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
34 #include "tensorflow/lite/delegates/gpu/common/transformations/matching.h"
35 
36 namespace tflite {
37 namespace gpu {
38 namespace {
39 
40 template <typename Attr>
41 class MergePaddingWith2DOperation : public SequenceTransformation {
42  public:
MergePaddingWith2DOperation(OperationType operation_type)43   explicit MergePaddingWith2DOperation(OperationType operation_type)
44       : operations_to_match_(
45             {ToString(OperationType::PAD), ToString(operation_type)}) {}
46 
ExpectedSequenceLength() const47   int ExpectedSequenceLength() const final { return 2; }
48 
ApplyToNodesSequence(const std::vector<Node * > & sequence,GraphFloat32 * graph)49   TransformResult ApplyToNodesSequence(const std::vector<Node*>& sequence,
50                                        GraphFloat32* graph) final {
51     if (!MatchesByOperationType(sequence, operations_to_match_)) {
52       return {TransformStatus::SKIPPED, ""};
53     }
54 
55     Node* pad_node = sequence.front();
56     Node* op_node = sequence.back();
57 
58     PadAttributes pad_attr =
59         absl::any_cast<PadAttributes>(pad_node->operation.attributes);
60 
61     if (pad_attr.type != PaddingContentType::ZEROS) {
62       return {TransformStatus::DECLINED, "Only Zero padding is supported."};
63     }
64     if (pad_attr.appended.c != 0 || pad_attr.prepended.c != 0 ||
65         pad_attr.appended.b != 0 || pad_attr.prepended.b != 0) {
66       return {TransformStatus::DECLINED,
67               "Pad has non-zero padding on non HW axis."};
68     }
69 
70     Attr* node_attr = absl::any_cast<Attr>(&op_node->operation.attributes);
71     absl::Status status = RemovePrecedingNode(graph, pad_node, op_node);
72     if (!status.ok()) {
73       return {TransformStatus::INVALID,
74               "Unable to remove Pad node with Operation node: " +
75                   std::string(status.message())};
76     }
77 
78     node_attr->padding.appended.h += pad_attr.appended.h;
79     node_attr->padding.appended.w += pad_attr.appended.w;
80     node_attr->padding.prepended.h += pad_attr.prepended.h;
81     node_attr->padding.prepended.w += pad_attr.prepended.w;
82     return {
83         TransformStatus::APPLIED,
84         absl::StrCat("Added padding: prepended = {h = ", pad_attr.prepended.h,
85                      ", w = ", pad_attr.prepended.w, "}, appended = { h = ",
86                      pad_attr.appended.h, ", w = ", pad_attr.appended.w, "}")};
87   }
88 
89  private:
90   const std::vector<std::string> operations_to_match_;
91 };
92 
93 }  // namespace
94 
NewMergePaddingWithPooling()95 std::unique_ptr<SequenceTransformation> NewMergePaddingWithPooling() {
96   return absl::make_unique<MergePaddingWith2DOperation<Pooling2DAttributes>>(
97       OperationType::POOLING_2D);
98 }
99 
NewMergePaddingWithConvolution2D()100 std::unique_ptr<SequenceTransformation> NewMergePaddingWithConvolution2D() {
101   return absl::make_unique<
102       MergePaddingWith2DOperation<Convolution2DAttributes>>(
103       OperationType::CONVOLUTION_2D);
104 }
105 
106 std::unique_ptr<SequenceTransformation>
NewMergePaddingWithDepthwiseConvolution()107 NewMergePaddingWithDepthwiseConvolution() {
108   return absl::make_unique<
109       MergePaddingWith2DOperation<DepthwiseConvolution2DAttributes>>(
110       OperationType::DEPTHWISE_CONVOLUTION);
111 }
112 
113 class MergePaddingWithAddOperation : public NodeTransformation {
114  public:
ApplyToNode(Node * node,GraphFloat32 * graph)115   TransformResult ApplyToNode(Node* node, GraphFloat32* graph) final {
116     if (node->operation.type != ToString(OperationType::PAD)) {
117       return {TransformStatus::SKIPPED, ""};
118     }
119     auto inputs = graph->FindInputs(node->id);
120     if (inputs.size() != 1) {
121       return {TransformStatus::SKIPPED, ""};
122     }
123 
124     const auto& input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
125     if (input_shape.c % 4 != 0) {
126       return {TransformStatus::DECLINED,
127               "Pad with input where src_channels % 4 != 0"};
128     }
129 
130     PadAttributes pad_attr =
131         absl::any_cast<PadAttributes>(node->operation.attributes);
132 
133     if (pad_attr.type != PaddingContentType::ZEROS) {
134       return {TransformStatus::DECLINED, "Only Zero padding is supported."};
135     }
136     if (pad_attr.prepended != BHWC(0, 0, 0, 0) || pad_attr.appended.h != 0 ||
137         pad_attr.appended.w != 0 || pad_attr.appended.b != 0) {
138       return {TransformStatus::DECLINED,
139               "Pad has padding not only in appended channels axis."};
140     }
141 
142     auto pad_output = graph->FindOutputs(node->id)[0];
143     auto consumer_nodes = graph->FindConsumers(pad_output->id);
144     if (consumer_nodes.size() != 1) {
145       return {TransformStatus::SKIPPED, ""};
146     }
147     auto add_node = consumer_nodes[0];
148     auto consumer_type = OperationTypeFromString(add_node->operation.type);
149     if (consumer_type != OperationType::ADD) {
150       return {TransformStatus::SKIPPED, ""};
151     }
152 
153     ElementwiseAttributes add_attr =
154         absl::any_cast<ElementwiseAttributes>(add_node->operation.attributes);
155     const bool is_add_hwc =
156         absl::holds_alternative<Tensor<HWC, DataType::FLOAT32>>(add_attr.param);
157     const bool is_add_linear =
158         absl::holds_alternative<Tensor<Linear, DataType::FLOAT32>>(
159             add_attr.param);
160     const bool is_add_scalar = absl::holds_alternative<float>(add_attr.param);
161     if (is_add_hwc || is_add_linear || is_add_scalar) {
162       return {TransformStatus::SKIPPED,
163               "Cannot remove padding when ADD has constant argument."};
164     }
165 
166     absl::Status status = RemovePrecedingNode(graph, node, add_node);
167     if (!status.ok()) {
168       return {TransformStatus::INVALID,
169               "Unable to remove Pad node " + std::string(status.message())};
170     }
171 
172     return {TransformStatus::APPLIED,
173             "Removed padding with zeroes in appended channels dimension"};
174   }
175 };
176 
NewMergePaddingWithAdd()177 std::unique_ptr<NodeTransformation> NewMergePaddingWithAdd() {
178   return absl::make_unique<MergePaddingWithAddOperation>();
179 }
180 
181 }  // namespace gpu
182 }  // namespace tflite
183