• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/detection_post_process.h"
18 #include "ops/op_utils.h"
19 #include "utils/check_convert_utils.h"
20 
21 namespace mindspore {
22 namespace ops {
Init(const int64_t inputSize,const std::vector<float> & scale,const float NmsIouThreshold,const float NmsScoreThreshold,const int64_t MaxDetections,const int64_t DetectionsPerClass,const int64_t MaxClassesPerDetection,const int64_t NumClasses,const bool UseRegularNms,const bool OutQuantized,const Format & format)23 void DetectionPostProcess::Init(const int64_t inputSize, const std::vector<float> &scale, const float NmsIouThreshold,
24                                 const float NmsScoreThreshold, const int64_t MaxDetections,
25                                 const int64_t DetectionsPerClass, const int64_t MaxClassesPerDetection,
26                                 const int64_t NumClasses, const bool UseRegularNms, const bool OutQuantized,
27                                 const Format &format) {
28   set_input_size(inputSize);
29   set_scale(scale);
30   set_nms_iou_threshold(NmsIouThreshold);
31   set_nms_score_threshold(NmsScoreThreshold);
32   set_max_detections(MaxDetections);
33   set_detections_per_class(DetectionsPerClass);
34   set_max_classes_per_detection(MaxClassesPerDetection);
35   set_num_classes(NumClasses);
36   set_use_regular_nms(UseRegularNms);
37   set_out_quantized(OutQuantized);
38   set_format(format);
39 }
40 
set_input_size(const int64_t inputSize)41 void DetectionPostProcess::set_input_size(const int64_t inputSize) {
42   (void)this->AddAttr(kInputSize, MakeValue(inputSize));
43 }
44 
get_input_size() const45 int64_t DetectionPostProcess::get_input_size() const {
46   auto value_ptr = this->GetAttr(kInputSize);
47   return GetValue<int64_t>(value_ptr);
48 }
49 
set_scale(const std::vector<float> & scale)50 void DetectionPostProcess::set_scale(const std::vector<float> &scale) { (void)this->AddAttr(kScale, MakeValue(scale)); }
get_scale() const51 std::vector<float> DetectionPostProcess::get_scale() const {
52   auto value_ptr = this->GetAttr(kScale);
53   return GetValue<std::vector<float>>(value_ptr);
54 }
55 
set_nms_iou_threshold(const float NmsIouThreshold)56 void DetectionPostProcess::set_nms_iou_threshold(const float NmsIouThreshold) {
57   (void)this->AddAttr(kNmsIouThreshold, MakeValue(NmsIouThreshold));
58 }
get_nms_iou_threshold() const59 float DetectionPostProcess::get_nms_iou_threshold() const {
60   auto value_ptr = this->GetAttr(kNmsIouThreshold);
61   return GetValue<float>(value_ptr);
62 }
63 
set_nms_score_threshold(const float NmsScoreThreshold)64 void DetectionPostProcess::set_nms_score_threshold(const float NmsScoreThreshold) {
65   (void)this->AddAttr(kNmsScoreThreshold, MakeValue(NmsScoreThreshold));
66 }
get_nms_score_threshold() const67 float DetectionPostProcess::get_nms_score_threshold() const {
68   auto value_ptr = this->GetAttr(kNmsScoreThreshold);
69   return GetValue<float>(value_ptr);
70 }
71 
set_max_detections(const int64_t MaxDetections)72 void DetectionPostProcess::set_max_detections(const int64_t MaxDetections) {
73   (void)this->AddAttr(kMaxDetections, MakeValue(MaxDetections));
74 }
get_max_detections() const75 int64_t DetectionPostProcess::get_max_detections() const { return GetValue<int64_t>(GetAttr(kMaxDetections)); }
76 
set_detections_per_class(const int64_t DetectionsPerClass)77 void DetectionPostProcess::set_detections_per_class(const int64_t DetectionsPerClass) {
78   (void)this->AddAttr(kDetectionsPerClass, MakeValue(DetectionsPerClass));
79 }
get_detections_per_class() const80 int64_t DetectionPostProcess::get_detections_per_class() const {
81   auto value_ptr = this->GetAttr(kDetectionsPerClass);
82   return GetValue<int64_t>(value_ptr);
83 }
84 
set_max_classes_per_detection(const int64_t MaxClassesPerDetection)85 void DetectionPostProcess::set_max_classes_per_detection(const int64_t MaxClassesPerDetection) {
86   (void)this->AddAttr(kMaxClassesPerDetection, MakeValue(MaxClassesPerDetection));
87 }
get_max_classes_per_detection() const88 int64_t DetectionPostProcess::get_max_classes_per_detection() const {
89   return GetValue<int64_t>(GetAttr(kMaxClassesPerDetection));
90 }
91 
set_num_classes(const int64_t NumClasses)92 void DetectionPostProcess::set_num_classes(const int64_t NumClasses) {
93   (void)this->AddAttr(kNumClasses, MakeValue(NumClasses));
94 }
get_num_classes() const95 int64_t DetectionPostProcess::get_num_classes() const { return GetValue<int64_t>(GetAttr(kNumClasses)); }
set_use_regular_nms(const bool UseRegularNms)96 void DetectionPostProcess::set_use_regular_nms(const bool UseRegularNms) {
97   (void)this->AddAttr(kUseRegularNms, MakeValue(UseRegularNms));
98 }
get_use_regular_nms() const99 bool DetectionPostProcess::get_use_regular_nms() const {
100   auto value_ptr = this->GetAttr(kUseRegularNms);
101   return GetValue<bool>(value_ptr);
102 }
103 
set_out_quantized(const bool OutQuantized)104 void DetectionPostProcess::set_out_quantized(const bool OutQuantized) {
105   (void)this->AddAttr(kOutQuantized, MakeValue(OutQuantized));
106 }
get_out_quantized() const107 bool DetectionPostProcess::get_out_quantized() const {
108   auto value_ptr = this->GetAttr(kOutQuantized);
109   return GetValue<bool>(value_ptr);
110 }
set_format(const Format & format)111 void DetectionPostProcess::set_format(const Format &format) {
112   int64_t f = format;
113   (void)this->AddAttr(kFormat, MakeValue(f));
114 }
get_format() const115 Format DetectionPostProcess::get_format() const { return Format(GetValue<int64_t>(GetAttr(kFormat))); }
DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)116 AbstractBasePtr DetectionPostProcessInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
117                                           const std::vector<AbstractBasePtr> &input_args) {
118   MS_EXCEPTION_IF_NULL(primitive);
119   auto prim_name = primitive->name();
120   const int64_t input_num = 3;
121   (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
122   MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
123   MS_EXCEPTION_IF_NULL(input_args[kInputIndex1]);
124   MS_EXCEPTION_IF_NULL(input_args[kInputIndex2]);
125   auto boxes = input_args[kInputIndex0];
126   auto scores = input_args[kInputIndex1];
127   auto anchors = input_args[kInputIndex2];
128   auto boxes_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(boxes->BuildShape())[kShape];
129   auto scores_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(scores->BuildShape())[kShape];
130   auto anchors_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(anchors->BuildShape())[kShape];
131   auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
132   if (format == NHWC) {
133     boxes_shape = {boxes_shape[0], boxes_shape[3], boxes_shape[1], boxes_shape[2]};
134     scores_shape = {scores_shape[0], scores_shape[3], scores_shape[1], scores_shape[2]};
135     anchors_shape = {anchors_shape[0], anchors_shape[3], anchors_shape[1], anchors_shape[2]};
136   }
137   auto num_classes = GetValue<int64_t>(primitive->GetAttr(kNumClasses));
138   CheckAndConvertUtils::CheckInRange("scores_shape[2]", scores_shape[2], kIncludeBoth, {num_classes, num_classes + 1},
139                                      prim_name);
140   CheckAndConvertUtils::Check("boxes_shape[1]", boxes_shape[1], kEqual, "scores_shape[1]", scores_shape[1], prim_name,
141                               ValueError);
142   CheckAndConvertUtils::Check("boxes_shape[1]", boxes_shape[1], kEqual, "anchors_shape[0]", anchors_shape[0], prim_name,
143                               ValueError);
144 
145   // Infer shape
146   auto max_detections = GetValue<int64_t>(primitive->GetAttr(kMaxDetections));
147   auto max_classes_per_detection = GetValue<int64_t>(primitive->GetAttr(kMaxClassesPerDetection));
148   auto num_detected_boxes = max_detections * max_classes_per_detection;
149   std::vector<int64_t> output_boxes_shape = {1, num_detected_boxes, 4};
150   std::vector<int64_t> output_class_shape = {1, num_detected_boxes};
151   std::vector<int64_t> output_num_shape = {1};
152 
153   // Infer type
154   auto output_type = kFloat32;
155 
156   auto output0 = std::make_shared<abstract::AbstractTensor>(output_type, output_boxes_shape);
157   auto output1 = std::make_shared<abstract::AbstractTensor>(output_type, output_class_shape);
158   auto output2 = std::make_shared<abstract::AbstractTensor>(output_type, output_num_shape);
159   AbstractBasePtrList output = {output0, output1, output1, output2};
160   if (format == NHWC) {
161     output = {output0, output1, output2, output1};
162   }
163   return std::make_shared<abstract::AbstractTuple>(output);
164 }
165 REGISTER_PRIMITIVE_C(kNameDetectionPostProcess, DetectionPostProcess);
166 }  // namespace ops
167 }  // namespace mindspore
168