1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
17
18 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
19 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
20
21 namespace mlir {
22 namespace TFL {
23
24 namespace {
25
26 // TODO(b/162842801): Consolidate all util definitions of kTFImplements.
27 constexpr char kTFImplements[] = "tf._implements";
28 constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
29 constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
30
CustomOption(OpBuilder * builder,const std::string & content)31 inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
32 const std::string& content) {
33 ShapedType type = RankedTensorType::get(
34 {static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
35 return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"),
36 type,
37 StringRef(content.data(), content.size()));
38 }
39
40 } // namespace
41
RewriteFunc()42 void ConvertNMSPaddedFunc::RewriteFunc() {
43 func_->setAttr(kTFImplements,
44 StringAttr::get(func_.getContext(), kTfNMSPadded));
45 Value boxes = func_.getArgument(0);
46 Value scores = func_.getArgument(1);
47 Value max_output_size = func_.getArgument(2);
48 Value iou_threshold = func_.getArgument(3);
49 Value score_threshold = func_.getArgument(4);
50 auto output_type0 = func_.getType().getResult(0);
51 auto output_type1 = func_.getType().getResult(1);
52
53 OpBuilder builder(func_.getBody());
54 auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
55 func_.getLoc(), output_type0, output_type1, boxes, scores,
56 max_output_size, iou_threshold, score_threshold);
57
58 builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResults());
59 }
60
VerifySignature()61 LogicalResult ConvertNMSPaddedFunc::VerifySignature() {
62 // Verify high-level function signature.
63 // Relevant argument characteristics are checked by the TFL op definition.
64 if (func_.getNumArguments() < 5) {
65 return func_.emitWarning()
66 << "Invalid number of arguments to "
67 "non_max_suppression_padded_v2 (need at least 5): "
68 << func_.getNumArguments();
69 }
70 if (func_.getType().getNumResults() != 2) {
71 return func_.emitWarning() << "Invalid number of results from "
72 "non_max_suppression_padded_v2 (need 2): "
73 << func_.getType().getNumResults();
74 }
75 // The TFLite fused op does not support batching yet.
76 // TODO(b/158709815): Add support for batches with padded NMS.
77 auto boxes_type = func_.getType().getInput(0).dyn_cast<RankedTensorType>();
78 if (boxes_type == nullptr || !boxes_type.hasRank() ||
79 boxes_type.getRank() != 2) {
80 return func_.emitWarning() << "TFLite does not support batched input for "
81 "non_max_suppression_padded";
82 }
83 return success();
84 }
85
RewriteFunc()86 LogicalResult ConvertSSDPostProcessFunc::RewriteFunc() {
87 func_.eraseBody();
88 func_.addEntryBlock();
89 func_->setAttr(kTFImplements,
90 StringAttr::get(func_.getContext(), kCustomSSDPostprocessing));
91
92 OpBuilder builder(func_.getBody());
93 std::string custom_option_buffer;
94 if (failed(CreateNMSCustomOptions(func_, attr_.getAttrs(),
95 custom_option_buffer))) {
96 return failure();
97 }
98 auto op = builder.create<CustomOp>(
99 func_.getLoc(), func_.getType().getResults(), func_.getArguments(),
100 kCustomSSDPostprocessing, CustomOption(&builder, custom_option_buffer));
101 builder.create<ReturnOp>(func_.getLoc(), op.getResults());
102
103 return success();
104 }
105
CreateNMSCustomOptions(FuncOp func,DictionaryAttr attrs,std::string & custom_option_buffer)106 LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions(
107 FuncOp func, DictionaryAttr attrs, std::string& custom_option_buffer) {
108 flexbuffers::Builder fbb;
109 size_t start_map = fbb.StartMap();
110
111 if (failed(AddIntAttr(func, attrs, "max_detections", &fbb)) ||
112 failed(AddIntAttr(func, attrs, "max_classes_per_detection", &fbb)) ||
113 failed(AddIntAttr(func, attrs, "num_classes", &fbb)) ||
114 failed(AddFloatAttr(func, attrs, "nms_score_threshold", &fbb)) ||
115 failed(AddFloatAttr(func, attrs, "nms_iou_threshold", &fbb)) ||
116 failed(AddFloatAttr(func, attrs, "y_scale", &fbb)) ||
117 failed(AddFloatAttr(func, attrs, "x_scale", &fbb)) ||
118 failed(AddFloatAttr(func, attrs, "h_scale", &fbb)) ||
119 failed(AddFloatAttr(func, attrs, "w_scale", &fbb)))
120 return failure();
121 auto use_regular_nms =
122 attrs.get("use_regular_nms").dyn_cast_or_null<BoolAttr>();
123 if (!use_regular_nms) {
124 return func.emitError()
125 << "use_regular_nms attribute is not set or not a bool";
126 }
127 fbb.Int("use_regular_nms", use_regular_nms.getValue());
128
129 fbb.EndMap(start_map);
130 fbb.Finish();
131 custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
132 return success();
133 }
134
AddIntAttr(FuncOp func,DictionaryAttr attrs,const std::string & attribute,flexbuffers::Builder * builder)135 LogicalResult ConvertSSDPostProcessFunc::AddIntAttr(
136 FuncOp func, DictionaryAttr attrs, const std::string& attribute,
137 flexbuffers::Builder* builder) {
138 auto int_attr = attrs.get(attribute).dyn_cast_or_null<IntegerAttr>();
139 if (!int_attr) {
140 return func.emitError()
141 << attribute.c_str() << " attribute is not set or not an integer";
142 }
143 builder->Int(attribute.c_str(), int_attr.getInt());
144 return success();
145 }
146
AddFloatAttr(FuncOp func,DictionaryAttr attrs,const std::string & attribute,flexbuffers::Builder * builder)147 LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr(
148 FuncOp func, DictionaryAttr attrs, const std::string& attribute,
149 flexbuffers::Builder* builder) {
150 auto float_attr = attrs.get(attribute).dyn_cast_or_null<FloatAttr>();
151 if (!float_attr) {
152 return func.emitError()
153 << attribute.c_str() << " attribute is not set or not a float";
154 }
155 builder->Float(attribute.c_str(), float_attr.getValue().convertToFloat());
156 return success();
157 }
158
HasIntAttr(FuncOp func,DictionaryAttr attrs,const std::string & attribute)159 LogicalResult ConvertSSDPostProcessFunc::HasIntAttr(
160 FuncOp func, DictionaryAttr attrs, const std::string& attribute) {
161 auto int_attr = attrs.get(attribute).dyn_cast_or_null<IntegerAttr>();
162 if (!int_attr) {
163 return func.emitWarning()
164 << attribute.c_str() << " attribute is not set or not an integer";
165 }
166 return success();
167 }
168
HasFloatAttr(FuncOp func,DictionaryAttr attrs,const std::string & attribute)169 LogicalResult ConvertSSDPostProcessFunc::HasFloatAttr(
170 FuncOp func, DictionaryAttr attrs, const std::string& attribute) {
171 auto float_attr = attrs.get(attribute).dyn_cast_or_null<FloatAttr>();
172 if (!float_attr) {
173 return func.emitWarning()
174 << attribute.c_str() << " attribute is not set or not a float";
175 }
176 return success();
177 }
178
VerifySignature()179 LogicalResult ConvertSSDPostProcessFunc::VerifySignature() {
180 // Verify high-level function signature.
181 if (func_.getNumArguments() != 3) {
182 return func_.emitWarning()
183 << "Invalid number of arguments to " << kCustomSSDPostprocessing
184 << ": " << func_.getNumArguments();
185 }
186 if (func_.getType().getNumResults() != 4) {
187 return func_.emitWarning()
188 << "Invalid number of results from " << kCustomSSDPostprocessing
189 << ": " << func_.getType().getNumResults();
190 }
191
192 auto attrs = attr_.getAttrs();
193 if (failed(HasIntAttr(func_, attrs, "max_detections")) ||
194 failed(HasIntAttr(func_, attrs, "max_classes_per_detection")) ||
195 failed(HasIntAttr(func_, attrs, "num_classes")) ||
196 failed(HasFloatAttr(func_, attrs, "nms_score_threshold")) ||
197 failed(HasFloatAttr(func_, attrs, "nms_iou_threshold")) ||
198 failed(HasFloatAttr(func_, attrs, "y_scale")) ||
199 failed(HasFloatAttr(func_, attrs, "x_scale")) ||
200 failed(HasFloatAttr(func_, attrs, "h_scale")) ||
201 failed(HasFloatAttr(func_, attrs, "w_scale"))) {
202 return failure();
203 }
204 return success();
205 }
206
207 } // namespace TFL
208 } // namespace mlir
209