1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "ops/roi_pooling.h"
17 #include <string>
18 #include <algorithm>
19 #include <memory>
20 #include <set>
21 #include <vector>
22 #include "ops/op_utils.h"
23 #include "utils/check_convert_utils.h"
24 #include "abstract/primitive_infer_map.h"
25
26 namespace mindspore {
27 namespace ops {
set_pooled_h(const int64_t pooled_h)28 void ROIPooling::set_pooled_h(const int64_t pooled_h) { (void)this->AddAttr(kPooledH, MakeValue(pooled_h)); }
29
get_pooled_h() const30 int64_t ROIPooling::get_pooled_h() const { return GetValue<int64_t>(GetAttr(kPooledH)); }
31
set_pooled_w(const int64_t pooled_w)32 void ROIPooling::set_pooled_w(const int64_t pooled_w) { (void)this->AddAttr(kPooledW, MakeValue(pooled_w)); }
33
get_pooled_w() const34 int64_t ROIPooling::get_pooled_w() const {
35 auto value_ptr = GetAttr(kPooledW);
36 return GetValue<int64_t>(value_ptr);
37 }
38
set_scale(const float scale)39 void ROIPooling::set_scale(const float scale) { (void)this->AddAttr(kScale, MakeValue(scale)); }
40
get_scale() const41 float ROIPooling::get_scale() const {
42 auto value_ptr = GetAttr(kScale);
43 return GetValue<float>(value_ptr);
44 }
45
Init(const int64_t pooled_h,const int64_t pooled_w,const float scale)46 void ROIPooling::Init(const int64_t pooled_h, const int64_t pooled_w, const float scale) {
47 this->set_pooled_h(pooled_h);
48 this->set_pooled_w(pooled_w);
49 this->set_scale(scale);
50 }
ROIPoolingInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)51 AbstractBasePtr ROIPoolingInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
52 const std::vector<AbstractBasePtr> &input_args) {
53 MS_EXCEPTION_IF_NULL(primitive);
54 auto prim_name = primitive->name();
55 const int64_t input_num = 2;
56 (void)CheckAndConvertUtils::CheckInteger("infer", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
57 MS_EXCEPTION_IF_NULL(input_args[0]);
58 MS_EXCEPTION_IF_NULL(input_args[1]);
59
60 // Infer type
61 auto output_data_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
62
63 // Infer shape
64 auto new_h = GetValue<int64_t>(primitive->GetAttr(kPooledH));
65 auto new_w = GetValue<int64_t>(primitive->GetAttr(kPooledW));
66 auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
67 auto roi_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
68 std::vector<int64_t> output_shape;
69 output_shape.push_back(roi_shape[0]);
70 output_shape.push_back(new_h);
71 output_shape.push_back(new_w);
72 output_shape.push_back(input_shape[1]);
73
74 return std::make_shared<abstract::AbstractTensor>(output_data_type, std::make_shared<abstract::Shape>(output_shape));
75 }
76 REGISTER_PRIMITIVE_C(kNameROIPooling, ROIPooling);
77 } // namespace ops
78 } // namespace mindspore
79