/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <memory>
#include <algorithm>

#include "ops/reduce_sum.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace ops {
namespace {
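// Normalizes a reduction axis into the range [0, dim) and raises an exception for
// out-of-range values.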
int64_t InferImplReduceFuncCheckAxis(const int64_t &axis, const size_t dim) {
  int64_t dim_ = static_cast<int64_t>(dim);
  if (axis < -dim_ || axis >= dim_) {
    MS_LOG(EXCEPTION) << "The axis should be in [" << -dim_ << ", " << dim_ << "), but got axis = " << axis;
  }
  int64_t ret_axis = axis;
  if (axis >= -dim_ && axis < 0) {
    ret_axis += dim_;
  }
  return ret_axis;
}

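// Computes the reduced output shape from the input shape and the axis value: reduced
// dimensions become 1 when keep_dims is true and are erased otherwise.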
void InferImplReduceFuncCalShape(ShapeVector *shape, const ShapeVector &x_shape, const ValuePtr &axis,
                                 bool keep_dims_value) {
  if (axis->isa<ValueTuple>() || axis->isa<ValueList>()) {
    auto axis_ptr_list =
      axis->isa<ValueTuple>() ? axis->cast<ValueTuplePtr>()->value() : axis->cast<ValueListPtr>()->value();
    if (axis_ptr_list.empty()) {
      // An empty axis list reduces over all axes; the rank survives only with keep_dims.
      if (keep_dims_value) (void)shape->insert(shape->end(), x_shape.size(), 1);
    } else {
      (void)shape->insert(shape->end(), x_shape.begin(), x_shape.end());
      ValuePtrList axis_items = axis_ptr_list;
      ValuePtrList::iterator it;
      if (keep_dims_value) {
        for (it = axis_items.begin(); it != axis_items.end(); ++it) {
          // Normalize negative axes before indexing; LongToSize on a negative value would overflow.
          auto axis_value = InferImplReduceFuncCheckAxis(GetValue<int64_t>(*it), x_shape.size());
          shape->at(LongToSize(axis_value)) = 1;
        }
      }
      } else {
        std::vector<int64_t> axis_value_list;
        for (it = axis_items.begin(); it != axis_items.end(); ++it) {
          auto axis_value = GetValue<int64_t>(*it);
          auto axis_positive_value = InferImplReduceFuncCheckAxis(axis_value, x_shape.size());
          axis_value_list.push_back(axis_positive_value);
        }
        std::sort(axis_value_list.begin(), axis_value_list.end());
        std::vector<int64_t>::reverse_iterator it_re;
        // Erase from the highest axis down so earlier erasures do not shift pending indices.
        for (it_re = axis_value_list.rbegin(); it_re != axis_value_list.rend(); ++it_re) {
          (void)shape->erase(shape->begin() + *it_re);
        }
      }
    }
  } else if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
    (void)shape->insert(shape->end(), x_shape.begin(), x_shape.end());
    int64_t axis_value = GetValue<int64_t>(axis);
    axis_value = InferImplReduceFuncCheckAxis(axis_value, x_shape.size());
    if (keep_dims_value) {
      shape->at(LongToSize(axis_value)) = 1;
    } else {
      (void)shape->erase(shape->begin() + axis_value);
    }
  } else {
    MS_LOG(EXCEPTION) << "The axis should be one of these types: [int/tuple/list].";
  }
}

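// Infers the ReduceSum output shape, producing min/max shapes when the input is dynamic.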
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto shape_ptr = CheckAndConvertUtils::GetTensorInputShape("ReduceSum", input_args, 0);
  auto input_shape = shape_ptr->shape();
  auto input_min_shape = shape_ptr->min_shape();
  auto input_max_shape = shape_ptr->max_shape();
  auto keep_dims_value_ptr = primitive->GetAttr(kKeepDims);
  MS_EXCEPTION_IF_NULL(keep_dims_value_ptr);
  if (!keep_dims_value_ptr->isa<BoolImm>()) {
    MS_LOG(EXCEPTION) << "The keep_dims attribute should be a bool.";
  }
  bool keep_dims = GetValue<bool>(keep_dims_value_ptr);
  ShapeVector out_shape = {};
  ShapeVector out_min_shape = {};
  ShapeVector out_max_shape = {};
  int64_t max_v;
  if (shape_ptr->IsDynamic()) {
    max_v = *std::max_element(input_max_shape.begin(), input_max_shape.end());
  } else {
    max_v = *std::max_element(input_shape.begin(), input_shape.end());
  }
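  // When the axis is a second tensor input whose value is still unknown at compile
  // time, the reduced dimensions cannot be resolved, so emit dynamic placeholders.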
  const int64_t input_num_ascend = 2;
  if (input_args.size() == input_num_ascend && input_args[1]->isa<abstract::AbstractTensor>() &&
      input_args[1]->BuildValue()->isa<AnyValue>()) {
    auto axis_tensor = input_args[1]->cast<abstract::AbstractTensorPtr>();
    auto axis_shape = axis_tensor->shape()->shape();
    if (axis_shape.size() == 1 && axis_shape[0] == -1 && !keep_dims) {
      // The number of reduced axes is itself unknown, so even the output rank is unknown (-2).
      out_shape.push_back(-2);
      for (size_t i = 0; i < input_shape.size(); ++i) {
        out_min_shape.push_back(1);
        out_max_shape.push_back(max_v);
      }
    } else if (!keep_dims) {
      // The output rank is known, but every remaining dimension is unknown (-1).
      for (size_t i = 0; i < input_shape.size() - axis_shape.size(); ++i) {
        out_shape.push_back(-1);
        out_min_shape.push_back(1);
        out_max_shape.push_back(max_v);
      }
    } else {
      // keep_dims preserves the rank, and every dimension is unknown (-1).
      for (size_t i = 0; i < input_shape.size(); ++i) {
        out_shape.push_back(-1);
        out_min_shape.push_back(1);
        out_max_shape.push_back(max_v);
      }
    }
    return std::make_shared<abstract::Shape>(out_shape, out_min_shape, out_max_shape);
  } else {
    ValuePtr axis_value;
    ValuePtr axis_ptr;
    // The axis may arrive as a second input (the input_num_ascend form) or as a primitive attribute.
    if (input_args.size() == input_num_ascend) {
      axis_ptr = input_args[1]->BuildValue();
    } else {
      axis_ptr = primitive->GetAttr("axis");
    }
    MS_EXCEPTION_IF_NULL(axis_ptr);
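    // If the axis value is a constant tensor, unpack its elements into a ValueTuple so
    // that InferImplReduceFuncCalShape can consume it uniformly.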
    if (axis_ptr->isa<tensor::Tensor>()) {
      MS_LOG(DEBUG) << "The axis is a tensor with a constant value.";
      auto axis_type = input_args[1]->BuildType();
      MS_EXCEPTION_IF_NULL(axis_type);
      auto axis_type_id = axis_type->cast<TensorTypePtr>();
      MS_EXCEPTION_IF_NULL(axis_type_id);
      auto axis_tensor = axis_ptr->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(axis_tensor);
      size_t data_size = LongToSize(axis_tensor->DataSize());
      std::vector<ValuePtr> value_list;
      if (axis_type_id->element()->type_id() == kNumberTypeInt32) {
        auto shape_data = reinterpret_cast<int *>(axis_tensor->data_c());
        MS_EXCEPTION_IF_NULL(shape_data);
        for (size_t i = 0; i < data_size; i++) {
          value_list.push_back(MakeValue(static_cast<int64_t>(*shape_data)));
          ++shape_data;
        }
      } else {
        auto shape_data2 = reinterpret_cast<int64_t *>(axis_tensor->data_c());
        MS_EXCEPTION_IF_NULL(shape_data2);
        for (size_t i = 0; i < data_size; i++) {
          value_list.push_back(MakeValue(*shape_data2));
          ++shape_data2;
        }
      }
      axis_value = std::make_shared<ValueTuple>(value_list);
    } else {
      axis_value = axis_ptr;
    }
    InferImplReduceFuncCalShape(&out_shape, input_shape, axis_value, keep_dims);

    if (!input_min_shape.empty() && !input_max_shape.empty()) {
      ShapeVector shape_min = {};
      ShapeVector shape_max = {};
      InferImplReduceFuncCalShape(&shape_min, input_min_shape, axis_value, keep_dims);
      InferImplReduceFuncCalShape(&shape_max, input_max_shape, axis_value, keep_dims);
      return std::make_shared<abstract::Shape>(out_shape, shape_min, shape_max);
    }
    return std::make_shared<abstract::Shape>(out_shape);
  }
}

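// Checks that the input tensor dtype is one of the commonly supported numeric types.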
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  return CheckAndConvertUtils::CheckTensorTypeValid("x dtype", input_args[0]->BuildType(), common_valid_types,
                                                    "ReduceSum");
}
}  // namespace

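// Abstract inference entry for ReduceSum: validates the argument count and combines
// the inferred shape and dtype into a single abstract value.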
AbstractBasePtr ReduceSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
  const int64_t input_num = 1;
  (void)CheckAndConvertUtils::CheckInteger("input size", SizeToInt(input_args.size()), kGreaterEqual, input_num,
                                           primitive->name());
  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
}  // namespace ops
}  // namespace mindspore