/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/infer/detection_post_process_infer.h"
#include <limits.h>
#include "nnacl/infer/infer_register.h"

DetectionPostProcessInferShape(const TensorC * const * inputs,size_t inputs_size,TensorC ** outputs,size_t outputs_size,OpParameter * parameter)20 int DetectionPostProcessInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs,
21                                    size_t outputs_size, OpParameter *parameter) {
22   int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 3, 4);
23   if (check_ret != NNACL_OK) {
24     return check_ret;
25   }
26 
27   const TensorC *boxes = inputs[0];
28   const TensorC *scores = inputs[1];
29   const TensorC *anchors = inputs[2];
30   if (boxes->shape_size_ < 2 || scores->shape_size_ < 3 || anchors->shape_size_ < 1) {
31     return NNACL_INPUT_TENSOR_ERROR;
32   }
33 
34   DetectionPostProcessParameter *param = (DetectionPostProcessParameter *)parameter;
35   if (scores->shape_[2] < param->num_classes_) {
36     return NNACL_ERR;
37   }
38   if (scores->shape_[2] - param->num_classes_ > 1) {
39     return NNACL_ERR;
40   }
41   if (boxes->shape_[1] != scores->shape_[1]) {
42     return NNACL_ERR;
43   }
44   if (boxes->shape_[1] != anchors->shape_[0]) {
45     return NNACL_ERR;
46   }
47 
48   TensorC *detected_boxes = outputs[0];
49   TensorC *detected_classes = outputs[1];
50   TensorC *detected_scores = outputs[2];
51   TensorC *num_det = outputs[3];
52 
53   detected_boxes->format_ = boxes->format_;
54   detected_boxes->data_type_ = kNumberTypeFloat32;
55   detected_classes->format_ = boxes->format_;
56   detected_classes->data_type_ = kNumberTypeFloat32;
57   detected_scores->format_ = boxes->format_;
58   detected_scores->data_type_ = kNumberTypeFloat32;
59   num_det->format_ = boxes->format_;
60   num_det->data_type_ = kNumberTypeFloat32;
61   if (!InferFlag(inputs, inputs_size)) {
62     return NNACL_INFER_INVALID;
63   }
64   const int max_detections = param->max_detections_;
65   const int max_classes_per_detection = param->max_classes_per_detection_;
66   const int num_detected_boxes = (int)(max_detections * max_classes_per_detection);
67   detected_boxes->shape_size_ = 3;
68   detected_boxes->shape_[0] = 1;
69   detected_boxes->shape_[1] = num_detected_boxes;
70   detected_boxes->shape_[2] = 4;
71   detected_classes->shape_size_ = 2;
72   detected_classes->shape_[0] = 1;
73   detected_classes->shape_[1] = num_detected_boxes;
74   detected_scores->shape_size_ = 2;
75   detected_scores->shape_[0] = 1;
76   detected_scores->shape_[1] = num_detected_boxes;
77   num_det->shape_size_ = 1;
78   num_det->shape_[0] = 1;
79 
80   return NNACL_OK;
81 }
82 
83 REG_INFER(DetectionPostProcess, PrimType_DetectionPostProcess, DetectionPostProcessInferShape)
84