/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include "ops/roi_pooling.h"
17 #include <string>
18 #include <algorithm>
19 #include <memory>
20 #include <set>
21 #include <vector>
22 #include "ops/op_utils.h"
23 #include "utils/check_convert_utils.h"
24 #include "abstract/primitive_infer_map.h"
25 
26 namespace mindspore {
27 namespace ops {
set_pooled_h(const int64_t pooled_h)28 void ROIPooling::set_pooled_h(const int64_t pooled_h) { (void)this->AddAttr(kPooledH, MakeValue(pooled_h)); }
29 
get_pooled_h() const30 int64_t ROIPooling::get_pooled_h() const { return GetValue<int64_t>(GetAttr(kPooledH)); }
31 
set_pooled_w(const int64_t pooled_w)32 void ROIPooling::set_pooled_w(const int64_t pooled_w) { (void)this->AddAttr(kPooledW, MakeValue(pooled_w)); }
33 
get_pooled_w() const34 int64_t ROIPooling::get_pooled_w() const {
35   auto value_ptr = GetAttr(kPooledW);
36   return GetValue<int64_t>(value_ptr);
37 }
38 
set_scale(const float scale)39 void ROIPooling::set_scale(const float scale) { (void)this->AddAttr(kScale, MakeValue(scale)); }
40 
get_scale() const41 float ROIPooling::get_scale() const {
42   auto value_ptr = GetAttr(kScale);
43   return GetValue<float>(value_ptr);
44 }
45 
Init(const int64_t pooled_h,const int64_t pooled_w,const float scale)46 void ROIPooling::Init(const int64_t pooled_h, const int64_t pooled_w, const float scale) {
47   this->set_pooled_h(pooled_h);
48   this->set_pooled_w(pooled_w);
49   this->set_scale(scale);
50 }
ROIPoolingInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)51 AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
52                                 const std::vector<AbstractBasePtr> &input_args) {
53   MS_EXCEPTION_IF_NULL(primitive);
54   auto prim_name = primitive->name();
55   const int64_t input_num = 2;
56   (void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
57   MS_EXCEPTION_IF_NULL(input_args[0]);
58   MS_EXCEPTION_IF_NULL(input_args[1]);
59 
60   // Infer type
61   auto output_data_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
62 
63   // Infer shape
64   auto new_h = GetValue<int64_t>(primitive->GetAttr(kPooledH));
65   auto new_w = GetValue<int64_t>(primitive->GetAttr(kPooledW));
66   auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
67   auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
68   std::vector<int64_t> output_shape;
69   output_shape.push_back(roi_shape[0]);
70   output_shape.push_back(new_h);
71   output_shape.push_back(new_w);
72   output_shape.push_back(input_shape[1]);
73 
74   return std::make_shared<abstract::AbstractTensor>(output_data_type, std::make_shared<abstract::Shape>(output_shape));
75 }
76 REGISTER_PRIMITIVE_C(kNameROIPooling, ROIPooling);
77 }  // namespace ops
78 }  // namespace mindspore
79