• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include <memory>
18 #include <algorithm>
19 
20 #include "ops/reduce_sum.h"
21 #include "ops/op_utils.h"
22 
23 namespace mindspore {
24 namespace ops {
25 namespace {
InferImplReduceFuncCheckAxis(const int64_t & axis,const size_t dim)26 int64_t InferImplReduceFuncCheckAxis(const int64_t &axis, const size_t dim) {
27   int64_t dim_ = static_cast<int64_t>(dim);
28   if (axis < -dim_ || axis >= dim_) {
29     MS_LOG(EXCEPTION) << "axis should be in [" << -dim_ << ", " << dim_ << "). But got axis = " << axis;
30   }
31   int64_t ret_axis = axis;
32   if (axis >= -dim_ && axis < 0) {
33     ret_axis += dim_;
34   }
35   return ret_axis;
36 }
37 
InferImplReduceFuncCalShape(ShapeVector * shape,const ShapeVector & x_shape,const ValuePtr & axis,bool keep_dims_value)38 void InferImplReduceFuncCalShape(ShapeVector *shape, const ShapeVector &x_shape, const ValuePtr &axis,
39                                  bool keep_dims_value) {
40   if (axis->isa<ValueTuple>() || axis->isa<ValueList>()) {
41     auto axis_ptr_list =
42       axis->isa<ValueTuple>() ? axis->cast<ValueTuplePtr>()->value() : axis->cast<ValueListPtr>()->value();
43     if (!axis_ptr_list.size()) {
44       if (keep_dims_value) (void)shape->insert(shape->end(), x_shape.size(), 1);
45     } else {
46       (void)shape->insert(shape->end(), x_shape.begin(), x_shape.end());
47       ValuePtrList axis_items = axis_ptr_list;
48       ValuePtrList::iterator it;
49       if (keep_dims_value) {
50         for (it = axis_items.begin(); it != axis_items.end(); ++it) {
51           auto axis_value = GetValue<int64_t>(*it);
52           shape->at(LongToSize(axis_value)) = 1;
53         }
54       } else {
55         std::vector<int64_t> axis_value_list;
56         for (it = axis_items.begin(); it != axis_items.end(); ++it) {
57           auto axis_value = GetValue<int64_t>(*it);
58           auto axis_positive_value = InferImplReduceFuncCheckAxis(axis_value, x_shape.size());
59           axis_value_list.push_back(axis_positive_value);
60         }
61         std::sort(axis_value_list.begin(), axis_value_list.end());
62         std::vector<int64_t>::reverse_iterator it_re;
63         for (it_re = axis_value_list.rbegin(); it_re != axis_value_list.rend(); ++it_re) {
64           (void)shape->erase(shape->begin() + *it_re);
65         }
66       }
67     }
68   } else if (axis->isa<Int32Imm>() || axis->isa<Int64Imm>()) {
69     (void)shape->insert(shape->end(), x_shape.begin(), x_shape.end());
70     int64_t axis_value = GetValue<int64_t>(axis);
71     axis_value = InferImplReduceFuncCheckAxis(axis_value, x_shape.size());
72     if (keep_dims_value) {
73       shape->at(LongToSize(axis_value)) = 1;
74     } else {
75       (void)shape->erase(shape->begin() + axis_value);
76     }
77   } else {
78     MS_LOG(EXCEPTION) << "Axis should be one of types: [int/tuple/list].";
79   }
80   return;
81 }
82 
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)83 abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
84   MS_EXCEPTION_IF_NULL(primitive);
85   auto shape_ptr = CheckAndConvertUtils::GetTensorInputShape("ReduceSum", input_args, 0);
86   auto input_shape = shape_ptr->shape();
87   auto input_min_shape = shape_ptr->min_shape();
88   auto input_max_shape = shape_ptr->max_shape();
89   auto keep_dimis_value_ptr = primitive->GetAttr(kKeepDims);
90   MS_EXCEPTION_IF_NULL(keep_dimis_value_ptr);
91   if (!keep_dimis_value_ptr->isa<BoolImm>()) {
92     MS_LOG(EXCEPTION) << "Keep_dims should be Bool.";
93   }
94   bool keep_dims = GetValue<bool>(keep_dimis_value_ptr);
95   ShapeVector out_shape = {};
96   ShapeVector out_min_shape = {};
97   ShapeVector out_max_shape = {};
98   int64_t max_v;
99   if (shape_ptr->IsDynamic()) {
100     max_v = *max_element(input_max_shape.begin(), input_max_shape.end());
101   } else {
102     max_v = *max_element(input_shape.begin(), input_shape.end());
103   }
104   const int64_t input_num_ascend = 2;
105   if (input_args.size() == input_num_ascend && input_args[1]->isa<abstract::AbstractTensor>() &&
106       input_args[1]->BuildValue()->isa<AnyValue>()) {
107     auto axis_tensor = input_args[1]->cast<abstract::AbstractTensorPtr>();
108     auto axis_shape = axis_tensor->shape()->shape();
109     if (axis_shape.size() == 1 && axis_shape[0] == -1 && !keep_dims) {
110       out_shape.push_back(-2);
111       for (size_t i = 0; i < input_shape.size(); ++i) {
112         out_min_shape.push_back(1);
113         out_max_shape.push_back(max_v);
114       }
115     } else if (!keep_dims) {
116       for (size_t i = 0; i < input_shape.size() - axis_shape.size(); ++i) {
117         out_shape.push_back(-1);
118         out_min_shape.push_back(1);
119         out_max_shape.push_back(max_v);
120       }
121     } else {
122       for (size_t i = 0; i < input_shape.size(); ++i) {
123         out_shape.push_back(-1);
124         out_min_shape.push_back(1);
125         out_max_shape.push_back(max_v);
126       }
127     }
128     return std::make_shared<abstract::Shape>(out_shape, out_min_shape, out_max_shape);
129   } else {
130     ValuePtr axis_value;
131     ValuePtr axis_ptr;
132     if (input_args.size() == input_num_ascend) {
133       axis_ptr = input_args[1]->BuildValue();
134     } else {
135       axis_ptr = primitive->GetAttr("axis");
136     }
137     MS_EXCEPTION_IF_NULL(axis_ptr);
138     if (axis_ptr->isa<tensor::Tensor>()) {
139       MS_LOG(ERROR) << "Tensor with value";
140       auto axis_type = input_args[1]->BuildType();
141       MS_EXCEPTION_IF_NULL(axis_type);
142       auto axis_type_id = axis_type->cast<TensorTypePtr>();
143       MS_EXCEPTION_IF_NULL(axis_type_id);
144       auto axis_tensor = axis_ptr->cast<tensor::TensorPtr>();
145       MS_EXCEPTION_IF_NULL(axis_tensor);
146       size_t data_size = LongToSize(axis_tensor->DataSize());
147       std::vector<ValuePtr> value_list;
148       if (axis_type_id->element()->type_id() == kNumberTypeInt32) {
149         auto shape_data = reinterpret_cast<int *>(axis_tensor->data_c());
150         MS_EXCEPTION_IF_NULL(shape_data);
151         for (size_t i = 0; i < data_size; i++) {
152           value_list.push_back(MakeValue(static_cast<int64_t>(*shape_data)));
153           ++shape_data;
154         }
155       } else {
156         auto shape_data2 = reinterpret_cast<int64_t *>(axis_tensor->data_c());
157         for (size_t i = 0; i < data_size; i++) {
158           value_list.push_back(MakeValue(static_cast<int64_t>(*shape_data2)));
159           ++shape_data2;
160         }
161       }
162       axis_value = std::make_shared<ValueTuple>(value_list);
163     } else {
164       axis_value = axis_ptr;
165     }
166     InferImplReduceFuncCalShape(&out_shape, input_shape, axis_value, keep_dims);
167 
168     if (!input_min_shape.empty() && !input_max_shape.empty()) {
169       ShapeVector shape_min = {};
170       ShapeVector shape_max = {};
171       InferImplReduceFuncCalShape(&shape_min, input_min_shape, axis_value, keep_dims);
172       InferImplReduceFuncCalShape(&shape_max, input_max_shape, axis_value, keep_dims);
173       return std::make_shared<abstract::Shape>(out_shape, shape_min, shape_max);
174     }
175     return std::make_shared<abstract::Shape>(out_shape);
176   }
177 }
178 
InferType(const PrimitivePtr & prim,const std::vector<AbstractBasePtr> & input_args)179 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
180   MS_EXCEPTION_IF_NULL(prim);
181   return CheckAndConvertUtils::CheckTensorTypeValid("x dtype", input_args[0]->BuildType(), common_valid_types,
182                                                     "ReduceSum");
183 }
184 }  // namespace
185 
ReduceSumInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)186 AbstractBasePtr ReduceSumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
187                                const std::vector<AbstractBasePtr> &input_args) {
188   const int64_t input_num = 1;
189   (void)CheckAndConvertUtils::CheckInteger("input size", SizeToInt(input_args.size()), kGreaterEqual, input_num,
190                                            primitive->name());
191   return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
192 }
193 }  // namespace ops
194 }  // namespace mindspore
195