1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/utils/nms_utils.h"
17 
18 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
19 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
20 
21 namespace mlir {
22 namespace TFL {
23 
24 namespace {
25 
// Value written under kTFImplements on functions rewritten into the padded
// non-max-suppression TFL op.
constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2";
// Value written under kTFImplements on functions rewritten into the SSD
// post-processing custom op; the same string is used as that op's custom code.
constexpr char kCustomSSDPostprocessing[] = "TFLite_Detection_PostProcess";
// TODO(b/162842801): Consolidate all util definitions of kTFImplements.
// Name of the function attribute that records which composite op a function
// implements.
constexpr char kTFImplements[] = "tf._implements";
30 
CustomOption(OpBuilder * builder,const std::string & content)31 inline OpaqueElementsAttr CustomOption(OpBuilder* builder,
32                                        const std::string& content) {
33   ShapedType type = RankedTensorType::get(
34       {static_cast<int64_t>(content.size())}, builder->getIntegerType(8));
35   return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"),
36                                  type,
37                                  StringRef(content.data(), content.size()));
38 }
39 
40 }  // namespace
41 
RewriteFunc()42 void ConvertNMSPaddedFunc::RewriteFunc() {
43   func_->setAttr(kTFImplements,
44                  StringAttr::get(func_.getContext(), kTfNMSPadded));
45   Value boxes = func_.getArgument(0);
46   Value scores = func_.getArgument(1);
47   Value max_output_size = func_.getArgument(2);
48   Value iou_threshold = func_.getArgument(3);
49   Value score_threshold = func_.getArgument(4);
50   auto output_type0 = func_.getType().getResult(0);
51   auto output_type1 = func_.getType().getResult(1);
52 
53   OpBuilder builder(func_.getBody());
54   auto op = builder.create<mlir::TFL::NonMaxSuppressionV4Op>(
55       func_.getLoc(), output_type0, output_type1, boxes, scores,
56       max_output_size, iou_threshold, score_threshold);
57 
58   builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResults());
59 }
60 
VerifySignature()61 LogicalResult ConvertNMSPaddedFunc::VerifySignature() {
62   // Verify high-level function signature.
63   // Relevant argument characteristics are checked by the TFL op definition.
64   if (func_.getNumArguments() < 5) {
65     return func_.emitWarning()
66            << "Invalid number of arguments to "
67               "non_max_suppression_padded_v2 (need at least 5): "
68            << func_.getNumArguments();
69   }
70   if (func_.getType().getNumResults() != 2) {
71     return func_.emitWarning() << "Invalid number of results from "
72                                   "non_max_suppression_padded_v2 (need 2): "
73                                << func_.getType().getNumResults();
74   }
75   // The TFLite fused op does not support batching yet.
76   // TODO(b/158709815): Add support for batches with padded NMS.
77   auto boxes_type = func_.getType().getInput(0).dyn_cast<RankedTensorType>();
78   if (boxes_type == nullptr || !boxes_type.hasRank() ||
79       boxes_type.getRank() != 2) {
80     return func_.emitWarning() << "TFLite does not support batched input for "
81                                   "non_max_suppression_padded";
82   }
83   return success();
84 }
85 
RewriteFunc()86 LogicalResult ConvertSSDPostProcessFunc::RewriteFunc() {
87   func_.eraseBody();
88   func_.addEntryBlock();
89   func_->setAttr(kTFImplements,
90                  StringAttr::get(func_.getContext(), kCustomSSDPostprocessing));
91 
92   OpBuilder builder(func_.getBody());
93   std::string custom_option_buffer;
94   if (failed(CreateNMSCustomOptions(func_, attr_.getAttrs(),
95                                     custom_option_buffer))) {
96     return failure();
97   }
98   auto op = builder.create<CustomOp>(
99       func_.getLoc(), func_.getType().getResults(), func_.getArguments(),
100       kCustomSSDPostprocessing, CustomOption(&builder, custom_option_buffer));
101   builder.create<ReturnOp>(func_.getLoc(), op.getResults());
102 
103   return success();
104 }
105 
CreateNMSCustomOptions(FuncOp func,DictionaryAttr attrs,std::string & custom_option_buffer)106 LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions(
107     FuncOp func, DictionaryAttr attrs, std::string& custom_option_buffer) {
108   flexbuffers::Builder fbb;
109   size_t start_map = fbb.StartMap();
110 
111   if (failed(AddIntAttr(func, attrs, "max_detections", &fbb)) ||
112       failed(AddIntAttr(func, attrs, "max_classes_per_detection", &fbb)) ||
113       failed(AddIntAttr(func, attrs, "num_classes", &fbb)) ||
114       failed(AddFloatAttr(func, attrs, "nms_score_threshold", &fbb)) ||
115       failed(AddFloatAttr(func, attrs, "nms_iou_threshold", &fbb)) ||
116       failed(AddFloatAttr(func, attrs, "y_scale", &fbb)) ||
117       failed(AddFloatAttr(func, attrs, "x_scale", &fbb)) ||
118       failed(AddFloatAttr(func, attrs, "h_scale", &fbb)) ||
119       failed(AddFloatAttr(func, attrs, "w_scale", &fbb)))
120     return failure();
121   auto use_regular_nms =
122       attrs.get("use_regular_nms").dyn_cast_or_null<BoolAttr>();
123   if (!use_regular_nms) {
124     return func.emitError()
125            << "use_regular_nms attribute is not set or not a bool";
126   }
127   fbb.Int("use_regular_nms", use_regular_nms.getValue());
128 
129   fbb.EndMap(start_map);
130   fbb.Finish();
131   custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
132   return success();
133 }
134 
AddIntAttr(FuncOp func,DictionaryAttr attrs,const std::string & attribute,flexbuffers::Builder * builder)135 LogicalResult ConvertSSDPostProcessFunc::AddIntAttr(
136     FuncOp func, DictionaryAttr attrs, const std::string& attribute,
137     flexbuffers::Builder* builder) {
138   auto int_attr = attrs.get(attribute).dyn_cast_or_null<IntegerAttr>();
139   if (!int_attr) {
140     return func.emitError()
141            << attribute.c_str() << " attribute is not set or not an integer";
142   }
143   builder->Int(attribute.c_str(), int_attr.getInt());
144   return success();
145 }
146 
AddFloatAttr(FuncOp func,DictionaryAttr attrs,const std::string & attribute,flexbuffers::Builder * builder)147 LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr(
148     FuncOp func, DictionaryAttr attrs, const std::string& attribute,
149     flexbuffers::Builder* builder) {
150   auto float_attr = attrs.get(attribute).dyn_cast_or_null<FloatAttr>();
151   if (!float_attr) {
152     return func.emitError()
153            << attribute.c_str() << " attribute is not set or not a float";
154   }
155   builder->Float(attribute.c_str(), float_attr.getValue().convertToFloat());
156   return success();
157 }
158 
HasIntAttr(FuncOp func,DictionaryAttr attrs,const std::string & attribute)159 LogicalResult ConvertSSDPostProcessFunc::HasIntAttr(
160     FuncOp func, DictionaryAttr attrs, const std::string& attribute) {
161   auto int_attr = attrs.get(attribute).dyn_cast_or_null<IntegerAttr>();
162   if (!int_attr) {
163     return func.emitWarning()
164            << attribute.c_str() << " attribute is not set or not an integer";
165   }
166   return success();
167 }
168 
HasFloatAttr(FuncOp func,DictionaryAttr attrs,const std::string & attribute)169 LogicalResult ConvertSSDPostProcessFunc::HasFloatAttr(
170     FuncOp func, DictionaryAttr attrs, const std::string& attribute) {
171   auto float_attr = attrs.get(attribute).dyn_cast_or_null<FloatAttr>();
172   if (!float_attr) {
173     return func.emitWarning()
174            << attribute.c_str() << " attribute is not set or not a float";
175   }
176   return success();
177 }
178 
VerifySignature()179 LogicalResult ConvertSSDPostProcessFunc::VerifySignature() {
180   // Verify high-level function signature.
181   if (func_.getNumArguments() != 3) {
182     return func_.emitWarning()
183            << "Invalid number of arguments to " << kCustomSSDPostprocessing
184            << ": " << func_.getNumArguments();
185   }
186   if (func_.getType().getNumResults() != 4) {
187     return func_.emitWarning()
188            << "Invalid number of results from " << kCustomSSDPostprocessing
189            << ": " << func_.getType().getNumResults();
190   }
191 
192   auto attrs = attr_.getAttrs();
193   if (failed(HasIntAttr(func_, attrs, "max_detections")) ||
194       failed(HasIntAttr(func_, attrs, "max_classes_per_detection")) ||
195       failed(HasIntAttr(func_, attrs, "num_classes")) ||
196       failed(HasFloatAttr(func_, attrs, "nms_score_threshold")) ||
197       failed(HasFloatAttr(func_, attrs, "nms_iou_threshold")) ||
198       failed(HasFloatAttr(func_, attrs, "y_scale")) ||
199       failed(HasFloatAttr(func_, attrs, "x_scale")) ||
200       failed(HasFloatAttr(func_, attrs, "h_scale")) ||
201       failed(HasFloatAttr(func_, attrs, "w_scale"))) {
202     return failure();
203   }
204   return success();
205 }
206 
207 }  // namespace TFL
208 }  // namespace mlir
209