• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "ops/resize.h"

#include <cmath>

#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/base/shared_ptr.h"
#include "mindapi/ir/value.h"
#include "mindapi/src/helper.h"
#include "mindspore/core/ops/lite_ops.h"
#include "ops/op_name.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "utils/log_adapter.h"
#include "ops/op_utils.h"
28 
29 namespace mindspore {
30 namespace ops {
31 MIND_API_OPERATOR_IMPL(Resize, BaseOperator);
Init(const Format format,const ResizeMethod method,const int64_t new_height,const int64_t new_width,const bool preserve_aspect_ratio,const CoordinateTransformMode coordinate_transform_mode,const float cubic_coeff,const int64_t exclude_outside,const float extrapolation_value,const NearestMode nearest_mode)32 void Resize::Init(const Format format, const ResizeMethod method, const int64_t new_height, const int64_t new_width,
33                   const bool preserve_aspect_ratio, const CoordinateTransformMode coordinate_transform_mode,
34                   const float cubic_coeff, const int64_t exclude_outside, const float extrapolation_value,
35                   const NearestMode nearest_mode) {
36   this->set_format(format);
37   this->set_method(method);
38   this->set_new_height(new_height);
39   this->set_new_width(new_width);
40   this->set_preserve_aspect_ratio(preserve_aspect_ratio);
41   this->set_coordinate_transform_mode(coordinate_transform_mode);
42   this->set_cubic_coeff(cubic_coeff);
43   this->set_exclude_outside(exclude_outside);
44   this->set_extrapolation_value(extrapolation_value);
45   this->set_nearest_mode(nearest_mode);
46 }
set_format(const Format format)47 void Resize::set_format(const Format format) {
48   int64_t swi = format;
49   (void)this->AddAttr(kFormat, api::MakeValue(swi));
50 }
51 
set_method(const ResizeMethod method)52 void Resize::set_method(const ResizeMethod method) {
53   auto swi = static_cast<int64_t>(method);
54   (void)this->AddAttr(kMethod, api::MakeValue(swi));
55 }
56 
set_new_height(const int64_t new_height)57 void Resize::set_new_height(const int64_t new_height) { (void)this->AddAttr(kNewHeight, api::MakeValue(new_height)); }
58 
set_new_width(const int64_t new_width)59 void Resize::set_new_width(const int64_t new_width) { (void)this->AddAttr(kNewWidth, api::MakeValue(new_width)); }
60 
set_preserve_aspect_ratio(const bool preserve_aspect_ratio)61 void Resize::set_preserve_aspect_ratio(const bool preserve_aspect_ratio) {
62   (void)this->AddAttr(kPreserveAspectRatio, api::MakeValue(preserve_aspect_ratio));
63 }
64 
set_coordinate_transform_mode(const CoordinateTransformMode coordinate_transform_mode)65 void Resize::set_coordinate_transform_mode(const CoordinateTransformMode coordinate_transform_mode) {
66   int64_t swi = coordinate_transform_mode;
67   (void)this->AddAttr(kCoordinateTransformMode, api::MakeValue(swi));
68 }
69 
set_cubic_coeff(const float cubic_coeff)70 void Resize::set_cubic_coeff(const float cubic_coeff) { (void)this->AddAttr(kCubicCoeff, api::MakeValue(cubic_coeff)); }
71 
set_exclude_outside(const int64_t exclude_outside)72 void Resize::set_exclude_outside(const int64_t exclude_outside) {
73   (void)this->AddAttr(kExcludeOutside, api::MakeValue(exclude_outside));
74 }
75 
set_extrapolation_value(const float extrapolation_value)76 void Resize::set_extrapolation_value(const float extrapolation_value) {
77   (void)this->AddAttr(kExtrapolationValue, api::MakeValue(extrapolation_value));
78 }
79 
set_nearest_mode(const NearestMode nearest_mode)80 void Resize::set_nearest_mode(const NearestMode nearest_mode) {
81   int64_t swi = static_cast<int64_t>(nearest_mode);
82   (void)this->AddAttr(kNearestMode, api::MakeValue(swi));
83 }
84 
get_format() const85 Format Resize::get_format() const {
86   auto value_ptr = GetAttr(kFormat);
87   return Format(GetValue<int64_t>(value_ptr));
88 }
89 
get_method() const90 ResizeMethod Resize::get_method() const {
91   auto value_ptr = GetAttr(kMethod);
92   return ResizeMethod(GetValue<int64_t>(value_ptr));
93 }
94 
get_new_height() const95 int64_t Resize::get_new_height() const {
96   auto value_ptr = GetAttr(kNewHeight);
97   return GetValue<int64_t>(value_ptr);
98 }
99 
get_new_width() const100 int64_t Resize::get_new_width() const {
101   auto value_ptr = GetAttr(kNewWidth);
102   return GetValue<int64_t>(value_ptr);
103 }
get_preserve_aspect_ratio() const104 bool Resize::get_preserve_aspect_ratio() const {
105   auto value_ptr = GetAttr(kPreserveAspectRatio);
106   return GetValue<bool>(value_ptr);
107 }
get_coordinate_transform_mode() const108 CoordinateTransformMode Resize::get_coordinate_transform_mode() const {
109   auto value_ptr = GetAttr(kCoordinateTransformMode);
110   return CoordinateTransformMode(GetValue<int64_t>(value_ptr));
111 }
112 
get_cubic_coeff() const113 float Resize::get_cubic_coeff() const {
114   auto value_ptr = GetAttr(kCubicCoeff);
115   return GetValue<float>(value_ptr);
116 }
117 
get_exclude_outside() const118 int64_t Resize::get_exclude_outside() const {
119   auto value_ptr = GetAttr(kExcludeOutside);
120   return GetValue<int64_t>(value_ptr);
121 }
122 
get_extrapolation_value() const123 float Resize::get_extrapolation_value() const {
124   auto value_ptr = GetAttr(kExtrapolationValue);
125   return GetValue<float>(value_ptr);
126 }
127 
get_nearest_mode() const128 NearestMode Resize::get_nearest_mode() const {
129   auto value_ptr = GetAttr(kNearestMode);
130   return NearestMode(GetValue<int64_t>(value_ptr));
131 }
132 
133 namespace {
134 template <typename T>
CheckArrayValueValid(const std::optional<ArrayValue<T>> & input_opt,int64_t * new_height,int64_t * new_width)135 bool CheckArrayValueValid(const std::optional<ArrayValue<T>> &input_opt, int64_t *new_height, int64_t *new_width) {
136   if (input_opt.has_value()) {
137     const auto &input_array = input_opt.value();
138     if (input_array.size() > 0) {
139       return true;
140     }
141   }
142   *new_height = -1;
143   *new_width = -1;
144   return false;
145 }
146 
147 constexpr size_t kResizeInputSize = 2;
GetNewHeightAndWidth(const PrimitivePtr & primitive,const AbstractBasePtr & shape_abstract,const int64_t & in_height,const int64_t & in_width,int64_t * new_height,int64_t * new_width)148 void GetNewHeightAndWidth(const PrimitivePtr &primitive, const AbstractBasePtr &shape_abstract,
149                           const int64_t &in_height, const int64_t &in_width, int64_t *new_height, int64_t *new_width) {
150   if (!CheckAndConvertUtils::IsTensor(shape_abstract)) {
151     MS_LOG(EXCEPTION) << "For Resize, the inputs[1] must be a tensor, but got: " << shape_abstract->ToString() << ".";
152   }
153   auto shape_value = shape_abstract->GetValue();
154   auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(shape_abstract->GetShape())[kShape];
155   (void)CheckAndConvertUtils::CheckInteger("size dimension", SizeToLong(size_shape.size()), kEqual, 1,
156                                            primitive->name());
157   auto input_dtype = shape_abstract->GetType()->cast<TensorTypePtr>();
158   MS_EXCEPTION_IF_NULL(input_dtype);
159   auto tensor_type = input_dtype->element();
160   if (size_shape[0] == 1) {
161     // zoom factor
162     (void)CheckAndConvertUtils::CheckTypeValid("size", tensor_type, {kInt32}, primitive->name());
163     auto scale_value =
164       CheckAndConvertUtils::CheckTensorIntValue("size", shape_value, primitive->name(), shape_abstract->GetType());
165     auto scale = scale_value[0];
166     *new_height = (in_height == -1) ? -1 : (in_height + (in_height - 1) * (scale - 1));
167     *new_width = (in_width == -1) ? -1 : (in_width + (in_width - 1) * (scale - 1));
168     return;
169   }
170   if (size_shape[0] != kSize2 && size_shape[0] != kSize4) {
171     MS_LOG(EXCEPTION) << "For Resize, the inputs[1]'s shape must be (1, ), (2, ) or (4, ), but got " << size_shape;
172   }
173   size_t h_index = size_shape[0] == kSize2 ? 0 : kFormatNCHWIndexH;
174   size_t w_index = size_shape[0] == kSize2 ? 1 : kFormatNCHWIndexW;
175   if (tensor_type == kInt32) {
176     const auto &data_opt = GetArrayValue<int32_t>(shape_abstract);
177     if (!CheckArrayValueValid(data_opt, new_height, new_width)) {
178       return;
179     }
180     const auto &data_array = data_opt.value();
181     *new_height = IntToLong(data_array[h_index]);
182     *new_width = IntToLong(data_array[w_index]);
183   } else if (tensor_type == kFloat32) {
184     auto data_opt = GetArrayValue<float>(shape_abstract);
185     if (!CheckArrayValueValid(data_opt, new_height, new_width)) {
186       return;
187     }
188     const auto &data_array = data_opt.value();
189     *new_height = (in_height == -1) ? -1 : round(data_array[h_index] * LongToFloat(in_height));
190     *new_width = (in_width == -1) ? -1 : round(data_array[w_index] * LongToFloat(in_width));
191   } else {
192     MS_LOG(EXCEPTION) << "For Resize, the inputs[1] datatype " << tensor_type->ToString() << " is not supported.";
193   }
194 }
195 
ResizeInferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)196 abstract::ShapePtr ResizeInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
197   std::vector<int64_t> output_shape(4, -1);
198   auto images_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShape())[kShape];
199   if (IsDynamicRank(images_shape)) {
200     output_shape = {abstract::TensorShape::kShapeRankAny};
201     return std::make_shared<abstract::TensorShape>(output_shape);
202   }
203   constexpr int64_t image_shape_size = 4;
204   (void)CheckAndConvertUtils::CheckInteger("images dimension", SizeToLong(images_shape.size()), kEqual,
205                                            image_shape_size, primitive->name());
206 
207   output_shape[0] = images_shape[0];
208   output_shape[kFormatNCHWIndexC] = images_shape[kFormatNCHWIndexC];
209 
210   int64_t new_height = -1;
211   int64_t new_width = -1;
212   if (primitive->GetAttr(kNewHeight) != nullptr) {
213     new_height = GetValue<int64_t>(primitive->GetAttr(kNewHeight));
214   }
215   if (primitive->GetAttr(kNewWidth) != nullptr) {
216     new_width = GetValue<int64_t>(primitive->GetAttr(kNewWidth));
217   }
218   if (input_args.size() == kResizeInputSize) {
219     auto size_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShape())[kShape];
220     if (IsDynamic(size_shape)) {
221       return std::make_shared<abstract::TensorShape>(output_shape);
222     }
223     GetNewHeightAndWidth(primitive, input_args[1], images_shape[kFormatNCHWIndexH], images_shape[kFormatNCHWIndexW],
224                          &new_height, &new_width);
225   }
226 
227   output_shape[kFormatNCHWIndexH] = new_height;
228   output_shape[kFormatNCHWIndexW] = new_width;
229   return std::make_shared<abstract::TensorShape>(output_shape);
230 }
231 
ResizeInferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)232 TypePtr ResizeInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
233   auto type = CheckAndConvertUtils::GetTensorInputType(primitive->name(), input_args, 0);
234   return type;
235 }
236 }  // namespace
237 
238 class MIND_API ResizeInfer : public abstract::OpInferBase {
239  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const240   BaseShapePtr InferShape(const PrimitivePtr &primitive,
241                           const std::vector<AbstractBasePtr> &input_args) const override {
242     return ResizeInferShape(primitive, input_args);
243   }
244 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const245   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
246     return ResizeInferType(primitive, input_args);
247   }
248 };
249 
250 REGISTER_PRIMITIVE_OP_INFER_IMPL(Resize, prim::kPrimResize, ResizeInfer, false);
251 }  // namespace ops
252 }  // namespace mindspore
253