1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/resize.h"
18 #include "abstract/ops/primitive_infer_map.h"
19 #include "mindapi/base/shared_ptr.h"
20 #include "mindapi/ir/value.h"
21 #include "mindapi/src/helper.h"
22 #include "mindspore/core/ops/lite_ops.h"
23 #include "ops/op_name.h"
24 #include "ops/primitive_c.h"
25 #include "utils/check_convert_utils.h"
26 #include "utils/log_adapter.h"
27 #include "ops/op_utils.h"
28
29 namespace mindspore {
30 namespace ops {
31 MIND_API_OPERATOR_IMPL(Resize, BaseOperator);
Init(const Format format,const ResizeMethod method,const int64_t new_height,const int64_t new_width,const bool preserve_aspect_ratio,const CoordinateTransformMode coordinate_transform_mode,const float cubic_coeff,const int64_t exclude_outside,const float extrapolation_value,const NearestMode nearest_mode)32 void Resize::Init(const Format format, const ResizeMethod method, const int64_t new_height, const int64_t new_width,
33 const bool preserve_aspect_ratio, const CoordinateTransformMode coordinate_transform_mode,
34 const float cubic_coeff, const int64_t exclude_outside, const float extrapolation_value,
35 const NearestMode nearest_mode) {
36 this->set_format(format);
37 this->set_method(method);
38 this->set_new_height(new_height);
39 this->set_new_width(new_width);
40 this->set_preserve_aspect_ratio(preserve_aspect_ratio);
41 this->set_coordinate_transform_mode(coordinate_transform_mode);
42 this->set_cubic_coeff(cubic_coeff);
43 this->set_exclude_outside(exclude_outside);
44 this->set_extrapolation_value(extrapolation_value);
45 this->set_nearest_mode(nearest_mode);
46 }
set_format(const Format format)47 void Resize::set_format(const Format format) {
48 int64_t swi = format;
49 (void)this->AddAttr(kFormat, api::MakeValue(swi));
50 }
51
set_method(const ResizeMethod method)52 void Resize::set_method(const ResizeMethod method) {
53 auto swi = static_cast<int64_t>(method);
54 (void)this->AddAttr(kMethod, api::MakeValue(swi));
55 }
56
set_new_height(const int64_t new_height)57 void Resize::set_new_height(const int64_t new_height) { (void)this->AddAttr(kNewHeight, api::MakeValue(new_height)); }
58
set_new_width(const int64_t new_width)59 void Resize::set_new_width(const int64_t new_width) { (void)this->AddAttr(kNewWidth, api::MakeValue(new_width)); }
60
set_preserve_aspect_ratio(const bool preserve_aspect_ratio)61 void Resize::set_preserve_aspect_ratio(const bool preserve_aspect_ratio) {
62 (void)this->AddAttr(kPreserveAspectRatio, api::MakeValue(preserve_aspect_ratio));
63 }
64
set_coordinate_transform_mode(const CoordinateTransformMode coordinate_transform_mode)65 void Resize::set_coordinate_transform_mode(const CoordinateTransformMode coordinate_transform_mode) {
66 int64_t swi = coordinate_transform_mode;
67 (void)this->AddAttr(kCoordinateTransformMode, api::MakeValue(swi));
68 }
69
set_cubic_coeff(const float cubic_coeff)70 void Resize::set_cubic_coeff(const float cubic_coeff) { (void)this->AddAttr(kCubicCoeff, api::MakeValue(cubic_coeff)); }
71
set_exclude_outside(const int64_t exclude_outside)72 void Resize::set_exclude_outside(const int64_t exclude_outside) {
73 (void)this->AddAttr(kExcludeOutside, api::MakeValue(exclude_outside));
74 }
75
set_extrapolation_value(const float extrapolation_value)76 void Resize::set_extrapolation_value(const float extrapolation_value) {
77 (void)this->AddAttr(kExtrapolationValue, api::MakeValue(extrapolation_value));
78 }
79
set_nearest_mode(const NearestMode nearest_mode)80 void Resize::set_nearest_mode(const NearestMode nearest_mode) {
81 int64_t swi = static_cast<int64_t>(nearest_mode);
82 (void)this->AddAttr(kNearestMode, api::MakeValue(swi));
83 }
84
get_format() const85 Format Resize::get_format() const {
86 auto value_ptr = GetAttr(kFormat);
87 return Format(GetValue<int64_t>(value_ptr));
88 }
89
get_method() const90 ResizeMethod Resize::get_method() const {
91 auto value_ptr = GetAttr(kMethod);
92 return ResizeMethod(GetValue<int64_t>(value_ptr));
93 }
94
get_new_height() const95 int64_t Resize::get_new_height() const {
96 auto value_ptr = GetAttr(kNewHeight);
97 return GetValue<int64_t>(value_ptr);
98 }
99
get_new_width() const100 int64_t Resize::get_new_width() const {
101 auto value_ptr = GetAttr(kNewWidth);
102 return GetValue<int64_t>(value_ptr);
103 }
get_preserve_aspect_ratio() const104 bool Resize::get_preserve_aspect_ratio() const {
105 auto value_ptr = GetAttr(kPreserveAspectRatio);
106 return GetValue<bool>(value_ptr);
107 }
get_coordinate_transform_mode() const108 CoordinateTransformMode Resize::get_coordinate_transform_mode() const {
109 auto value_ptr = GetAttr(kCoordinateTransformMode);
110 return CoordinateTransformMode(GetValue<int64_t>(value_ptr));
111 }
112
get_cubic_coeff() const113 float Resize::get_cubic_coeff() const {
114 auto value_ptr = GetAttr(kCubicCoeff);
115 return GetValue<float>(value_ptr);
116 }
117
get_exclude_outside() const118 int64_t Resize::get_exclude_outside() const {
119 auto value_ptr = GetAttr(kExcludeOutside);
120 return GetValue<int64_t>(value_ptr);
121 }
122
get_extrapolation_value() const123 float Resize::get_extrapolation_value() const {
124 auto value_ptr = GetAttr(kExtrapolationValue);
125 return GetValue<float>(value_ptr);
126 }
127
get_nearest_mode() const128 NearestMode Resize::get_nearest_mode() const {
129 auto value_ptr = GetAttr(kNearestMode);
130 return NearestMode(GetValue<int64_t>(value_ptr));
131 }
132
133 namespace {
134 template <typename T>
CheckArrayValueValid(const std::optional<ArrayValue<T>> & input_opt,int64_t * new_height,int64_t * new_width)135 bool CheckArrayValueValid(const std::optional<ArrayValue<T>> &input_opt, int64_t *new_height, int64_t *new_width) {
136 if (input_opt.has_value()) {
137 const auto &input_array = input_opt.value();
138 if (input_array.size() > 0) {
139 return true;
140 }
141 }
142 *new_height = -1;
143 *new_width = -1;
144 return false;
145 }
146
147 constexpr size_t kResizeInputSize = 2;
GetNewHeightAndWidth(const PrimitivePtr & primitive,const AbstractBasePtr & shape_abstract,const int64_t & in_height,const int64_t & in_width,int64_t * new_height,int64_t * new_width)148 void GetNewHeightAndWidth(const PrimitivePtr &primitive, const AbstractBasePtr &shape_abstract,
149 const int64_t &in_height, const int64_t &in_width, int64_t *new_height, int64_t *new_width) {
150 if (!CheckAndConvertUtils::IsTensor(shape_abstract)) {
151 MS_LOG(EXCEPTION) << "For Resize, the inputs[1] must be a tensor, but got: " << shape_abstract->ToString() << ".";
152 }
153 auto shape_value = shape_abstract->GetValue();
154 auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(shape_abstract->GetShape())[kShape];
155 (void)CheckAndConvertUtils::CheckInteger("size dimension", SizeToLong(size_shape.size()), kEqual, 1,
156 primitive->name());
157 auto input_dtype = shape_abstract->GetType()->cast<TensorTypePtr>();
158 MS_EXCEPTION_IF_NULL(input_dtype);
159 auto tensor_type = input_dtype->element();
160 if (size_shape[0] == 1) {
161 // zoom factor
162 (void)CheckAndConvertUtils::CheckTypeValid("size", tensor_type, {kInt32}, primitive->name());
163 auto scale_value =
164 CheckAndConvertUtils::CheckTensorIntValue("size", shape_value, primitive->name(), shape_abstract->GetType());
165 auto scale = scale_value[0];
166 *new_height = (in_height == -1) ? -1 : (in_height + (in_height - 1) * (scale - 1));
167 *new_width = (in_width == -1) ? -1 : (in_width + (in_width - 1) * (scale - 1));
168 return;
169 }
170 if (size_shape[0] != kSize2 && size_shape[0] != kSize4) {
171 MS_LOG(EXCEPTION) << "For Resize, the inputs[1]'s shape must be (1, ), (2, ) or (4, ), but got " << size_shape;
172 }
173 size_t h_index = size_shape[0] == kSize2 ? 0 : kFormatNCHWIndexH;
174 size_t w_index = size_shape[0] == kSize2 ? 1 : kFormatNCHWIndexW;
175 if (tensor_type == kInt32) {
176 const auto &data_opt = GetArrayValue<int32_t>(shape_abstract);
177 if (!CheckArrayValueValid(data_opt, new_height, new_width)) {
178 return;
179 }
180 const auto &data_array = data_opt.value();
181 *new_height = IntToLong(data_array[h_index]);
182 *new_width = IntToLong(data_array[w_index]);
183 } else if (tensor_type == kFloat32) {
184 auto data_opt = GetArrayValue<float>(shape_abstract);
185 if (!CheckArrayValueValid(data_opt, new_height, new_width)) {
186 return;
187 }
188 const auto &data_array = data_opt.value();
189 *new_height = (in_height == -1) ? -1 : round(data_array[h_index] * LongToFloat(in_height));
190 *new_width = (in_width == -1) ? -1 : round(data_array[w_index] * LongToFloat(in_width));
191 } else {
192 MS_LOG(EXCEPTION) << "For Resize, the inputs[1] datatype " << tensor_type->ToString() << " is not supported.";
193 }
194 }
195
ResizeInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)196 abstract::ShapePtr ResizeInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
197 std::vector<int64_t> output_shape(4, -1);
198 auto images_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShape())[kShape];
199 if (IsDynamicRank(images_shape)) {
200 output_shape = {abstract::TensorShape::kShapeRankAny};
201 return std::make_shared<abstract::TensorShape>(output_shape);
202 }
203 constexpr int64_t image_shape_size = 4;
204 (void)CheckAndConvertUtils::CheckInteger("images dimension", SizeToLong(images_shape.size()), kEqual,
205 image_shape_size, primitive->name());
206
207 output_shape[0] = images_shape[0];
208 output_shape[kFormatNCHWIndexC] = images_shape[kFormatNCHWIndexC];
209
210 int64_t new_height = -1;
211 int64_t new_width = -1;
212 if (primitive->GetAttr(kNewHeight) != nullptr) {
213 new_height = GetValue<int64_t>(primitive->GetAttr(kNewHeight));
214 }
215 if (primitive->GetAttr(kNewWidth) != nullptr) {
216 new_width = GetValue<int64_t>(primitive->GetAttr(kNewWidth));
217 }
218 if (input_args.size() == kResizeInputSize) {
219 auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShape())[kShape];
220 if (IsDynamic(size_shape)) {
221 return std::make_shared<abstract::TensorShape>(output_shape);
222 }
223 GetNewHeightAndWidth(primitive, input_args[1], images_shape[kFormatNCHWIndexH], images_shape[kFormatNCHWIndexW],
224 &new_height, &new_width);
225 }
226
227 output_shape[kFormatNCHWIndexH] = new_height;
228 output_shape[kFormatNCHWIndexW] = new_width;
229 return std::make_shared<abstract::TensorShape>(output_shape);
230 }
231
ResizeInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)232 TypePtr ResizeInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
233 auto type = CheckAndConvertUtils::GetTensorInputType(primitive->name(), input_args, 0);
234 return type;
235 }
236 } // namespace
237
238 class MIND_API ResizeInfer : public abstract::OpInferBase {
239 public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const240 BaseShapePtr InferShape(const PrimitivePtr &primitive,
241 const std::vector<AbstractBasePtr> &input_args) const override {
242 return ResizeInferShape(primitive, input_args);
243 }
244
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const245 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
246 return ResizeInferType(primitive, input_args);
247 }
248 };
249
250 REGISTER_PRIMITIVE_OP_INFER_IMPL(Resize, prim::kPrimResize, ResizeInfer, false);
251 } // namespace ops
252 } // namespace mindspore
253