1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/detection_post_process.h"
18 #include "ops/op_utils.h"
19 #include "utils/check_convert_utils.h"
20
21 namespace mindspore {
22 namespace ops {
Init(const int64_t inputSize,const std::vector<float> & scale,const float NmsIouThreshold,const float NmsScoreThreshold,const int64_t MaxDetections,const int64_t DetectionsPerClass,const int64_t MaxClassesPerDetection,const int64_t NumClasses,const bool UseRegularNms,const bool OutQuantized,const Format & format)23 void DetectionPostProcess::Init(const int64_t inputSize, const std::vector<float> &scale, const float NmsIouThreshold,
24 const float NmsScoreThreshold, const int64_t MaxDetections,
25 const int64_t DetectionsPerClass, const int64_t MaxClassesPerDetection,
26 const int64_t NumClasses, const bool UseRegularNms, const bool OutQuantized,
27 const Format &format) {
28 set_input_size(inputSize);
29 set_scale(scale);
30 set_nms_iou_threshold(NmsIouThreshold);
31 set_nms_score_threshold(NmsScoreThreshold);
32 set_max_detections(MaxDetections);
33 set_detections_per_class(DetectionsPerClass);
34 set_max_classes_per_detection(MaxClassesPerDetection);
35 set_num_classes(NumClasses);
36 set_use_regular_nms(UseRegularNms);
37 set_out_quantized(OutQuantized);
38 set_format(format);
39 }
40
set_input_size(const int64_t inputSize)41 void DetectionPostProcess::set_input_size(const int64_t inputSize) {
42 (void)this->AddAttr(kInputSize, MakeValue(inputSize));
43 }
44
get_input_size() const45 int64_t DetectionPostProcess::get_input_size() const {
46 auto value_ptr = this->GetAttr(kInputSize);
47 return GetValue<int64_t>(value_ptr);
48 }
49
set_scale(const std::vector<float> & scale)50 void DetectionPostProcess::set_scale(const std::vector<float> &scale) { (void)this->AddAttr(kScale, MakeValue(scale)); }
get_scale() const51 std::vector<float> DetectionPostProcess::get_scale() const {
52 auto value_ptr = this->GetAttr(kScale);
53 return GetValue<std::vector<float>>(value_ptr);
54 }
55
set_nms_iou_threshold(const float NmsIouThreshold)56 void DetectionPostProcess::set_nms_iou_threshold(const float NmsIouThreshold) {
57 (void)this->AddAttr(kNmsIouThreshold, MakeValue(NmsIouThreshold));
58 }
get_nms_iou_threshold() const59 float DetectionPostProcess::get_nms_iou_threshold() const {
60 auto value_ptr = this->GetAttr(kNmsIouThreshold);
61 return GetValue<float>(value_ptr);
62 }
63
set_nms_score_threshold(const float NmsScoreThreshold)64 void DetectionPostProcess::set_nms_score_threshold(const float NmsScoreThreshold) {
65 (void)this->AddAttr(kNmsScoreThreshold, MakeValue(NmsScoreThreshold));
66 }
get_nms_score_threshold() const67 float DetectionPostProcess::get_nms_score_threshold() const {
68 auto value_ptr = this->GetAttr(kNmsScoreThreshold);
69 return GetValue<float>(value_ptr);
70 }
71
set_max_detections(const int64_t MaxDetections)72 void DetectionPostProcess::set_max_detections(const int64_t MaxDetections) {
73 (void)this->AddAttr(kMaxDetections, MakeValue(MaxDetections));
74 }
get_max_detections() const75 int64_t DetectionPostProcess::get_max_detections() const { return GetValue<int64_t>(GetAttr(kMaxDetections)); }
76
set_detections_per_class(const int64_t DetectionsPerClass)77 void DetectionPostProcess::set_detections_per_class(const int64_t DetectionsPerClass) {
78 (void)this->AddAttr(kDetectionsPerClass, MakeValue(DetectionsPerClass));
79 }
get_detections_per_class() const80 int64_t DetectionPostProcess::get_detections_per_class() const {
81 auto value_ptr = this->GetAttr(kDetectionsPerClass);
82 return GetValue<int64_t>(value_ptr);
83 }
84
set_max_classes_per_detection(const int64_t MaxClassesPerDetection)85 void DetectionPostProcess::set_max_classes_per_detection(const int64_t MaxClassesPerDetection) {
86 (void)this->AddAttr(kMaxClassesPerDetection, MakeValue(MaxClassesPerDetection));
87 }
get_max_classes_per_detection() const88 int64_t DetectionPostProcess::get_max_classes_per_detection() const {
89 return GetValue<int64_t>(GetAttr(kMaxClassesPerDetection));
90 }
91
set_num_classes(const int64_t NumClasses)92 void DetectionPostProcess::set_num_classes(const int64_t NumClasses) {
93 (void)this->AddAttr(kNumClasses, MakeValue(NumClasses));
94 }
get_num_classes() const95 int64_t DetectionPostProcess::get_num_classes() const { return GetValue<int64_t>(GetAttr(kNumClasses)); }
set_use_regular_nms(const bool UseRegularNms)96 void DetectionPostProcess::set_use_regular_nms(const bool UseRegularNms) {
97 (void)this->AddAttr(kUseRegularNms, MakeValue(UseRegularNms));
98 }
get_use_regular_nms() const99 bool DetectionPostProcess::get_use_regular_nms() const {
100 auto value_ptr = this->GetAttr(kUseRegularNms);
101 return GetValue<bool>(value_ptr);
102 }
103
set_out_quantized(const bool OutQuantized)104 void DetectionPostProcess::set_out_quantized(const bool OutQuantized) {
105 (void)this->AddAttr(kOutQuantized, MakeValue(OutQuantized));
106 }
get_out_quantized() const107 bool DetectionPostProcess::get_out_quantized() const {
108 auto value_ptr = this->GetAttr(kOutQuantized);
109 return GetValue<bool>(value_ptr);
110 }
set_format(const Format & format)111 void DetectionPostProcess::set_format(const Format &format) {
112 int64_t f = format;
113 (void)this->AddAttr(kFormat, MakeValue(f));
114 }
get_format() const115 Format DetectionPostProcess::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }
DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)116 AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
117 const std::vector<AbstractBasePtr> &input_args) {
118 MS_EXCEPTION_IF_NULL(primitive);
119 auto prim_name = primitive->name();
120 const int64_t input_num = 3;
121 (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
122 MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
123 MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
124 MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]);
125 auto boxes = input_args[kInputIndex0];
126 auto scores = input_args[kInputIndex1];
127 auto anchors = input_args[kInputIndex2];
128 auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(boxes->BuildShape())[kShape];
129 auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(scores->BuildShape())[kShape];
130 auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(anchors->BuildShape())[kShape];
131 auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
132 if (format == NHWC) {
133 boxes_shape = {boxes_shape[0], boxes_shape[3], boxes_shape[1], boxes_shape[2]};
134 scores_shape = {scores_shape[0], scores_shape[3], scores_shape[1], scores_shape[2]};
135 anchors_shape = {anchors_shape[0], anchors_shape[3], anchors_shape[1], anchors_shape[2]};
136 }
137 auto num_classes = GetValue<int64_t>(primitive->GetAttr(kNumClasses));
138 CheckAndConvertUtils::CheckInRange("scores_shape[2]", scores_shape[2], kIncludeBoth, {num_classes, num_classes + 1},
139 prim_name);
140 CheckAndConvertUtils::Check("boxes_shape[1]", boxes_shape[1], kEqual, "scores_shape[1]", scores_shape[1], prim_name,
141 ValueError);
142 CheckAndConvertUtils::Check("boxes_shape[1]", boxes_shape[1], kEqual, "anchors_shape[0]", anchors_shape[0], prim_name,
143 ValueError);
144
145 // Infer shape
146 auto max_detections = GetValue<int64_t>(primitive->GetAttr(kMaxDetections));
147 auto max_classes_per_detection = GetValue<int64_t>(primitive->GetAttr(kMaxClassesPerDetection));
148 auto num_detected_boxes = max_detections * max_classes_per_detection;
149 std::vector<int64_t> output_boxes_shape = {1, num_detected_boxes, 4};
150 std::vector<int64_t> output_class_shape = {1, num_detected_boxes};
151 std::vector<int64_t> output_num_shape = {1};
152
153 // Infer type
154 auto output_type = kFloat32;
155
156 auto output0 = std::make_shared<abstract::AbstractTensor>(output_type, output_boxes_shape);
157 auto output1 = std::make_shared<abstract::AbstractTensor>(output_type, output_class_shape);
158 auto output2 = std::make_shared<abstract::AbstractTensor>(output_type, output_num_shape);
159 AbstractBasePtrList output = {output0, output1, output1, output2};
160 if (format == NHWC) {
161 output = {output0, output1, output2, output1};
162 }
163 return std::make_shared<abstract::AbstractTuple>(output);
164 }
165 REGISTER_PRIMITIVE_C(kNameDetectionPostProcess, DetectionPostProcess);
166 } // namespace ops
167 } // namespace mindspore
168