1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/fusion/reduce_fusion.h"
18 #include "abstract/ops/primitive_infer_map.h"
19 #include "mindapi/base/shared_ptr.h"
20 #include "mindapi/ir/value.h"
21 #include "mindapi/src/helper.h"
22 #include "mindspore/core/ops/lite_ops.h"
23 #include "ops/op_name.h"
24 #include "ops/op_utils.h"
25 #include "ops/primitive_c.h"
26 #include "utils/check_convert_utils.h"
27 #include "utils/log_adapter.h"
28
29 namespace mindspore {
30 namespace ops {
// Expands the MindAPI operator boilerplate for ReduceFusion; the second argument names Reduce as its parent
// operator (NOTE(review): presumably the base class declared in ops/fusion/reduce_fusion.h — confirm there).
MIND_API_OPERATOR_IMPL(ReduceFusion, Reduce);
set_keep_dims(const bool keep_dims)32 void ReduceFusion::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKeepDims, api::MakeValue(keep_dims)); }
33
set_mode(const ReduceMode mode)34 void ReduceFusion::set_mode(const ReduceMode mode) {
35 int64_t swi = mode;
36 (void)this->AddAttr(kMode, api::MakeValue(swi));
37 }
38
set_reduce_to_end(const bool reduce_to_end)39 void ReduceFusion::set_reduce_to_end(const bool reduce_to_end) {
40 (void)this->AddAttr(kReduceToEnd, api::MakeValue(reduce_to_end));
41 }
42
set_coeff(const float coeff)43 void ReduceFusion::set_coeff(const float coeff) { (void)this->AddAttr(kCoeff, api::MakeValue(coeff)); }
44
get_keep_dims() const45 bool ReduceFusion::get_keep_dims() const {
46 auto value_ptr = GetAttr(kKeepDims);
47 MS_EXCEPTION_IF_NULL(value_ptr);
48 return GetValue<bool>(value_ptr);
49 }
50
get_mode() const51 ReduceMode ReduceFusion::get_mode() const {
52 auto value_ptr = GetAttr(kMode);
53 MS_EXCEPTION_IF_NULL(value_ptr);
54 return ReduceMode(GetValue<int64_t>(value_ptr));
55 }
56
get_reduce_to_end() const57 bool ReduceFusion::get_reduce_to_end() const {
58 auto value_ptr = GetAttr(kReduceToEnd);
59 MS_EXCEPTION_IF_NULL(value_ptr);
60 return GetValue<bool>(value_ptr);
61 }
62
get_coeff() const63 float ReduceFusion::get_coeff() const {
64 auto value_ptr = GetAttr(kCoeff);
65 MS_EXCEPTION_IF_NULL(value_ptr);
66 return GetValue<float>(value_ptr);
67 }
68
Init(const bool keep_dims,const ReduceMode mode,const bool reduce_to_end,const float coeff)69 void ReduceFusion::Init(const bool keep_dims, const ReduceMode mode, const bool reduce_to_end, const float coeff) {
70 this->set_keep_dims(keep_dims);
71 this->set_mode(mode);
72 this->set_reduce_to_end(reduce_to_end);
73 this->set_coeff(coeff);
74 }
75
76 namespace {
ReduceFusionInferShape(const PrimitivePtr & primitive,const std::vector<abstract::AbstractBasePtr> & input_args)77 abstract::ShapePtr ReduceFusionInferShape(const PrimitivePtr &primitive,
78 const std::vector<abstract::AbstractBasePtr> &input_args) {
79 MS_EXCEPTION_IF_NULL(primitive);
80 CheckAndConvertUtils::CheckArgsType(primitive->name(), input_args, 0, kObjectTypeTensorType);
81 auto x_shape = input_args[0]->GetShape()->GetShapeVector();
82
83 auto keep_dims_value_ptr = primitive->GetAttr(kKeepDims);
84 MS_EXCEPTION_IF_NULL(keep_dims_value_ptr);
85 if (!keep_dims_value_ptr->isa<BoolImm>()) {
86 MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', 'keep_dims' must be Bool.";
87 }
88 bool keep_dims = GetValue<bool>(keep_dims_value_ptr);
89
90 std::vector<int64_t> axis_value;
91 int64_t axis_shape = 1;
92 bool axis_is_dynamic = CheckAndGetAxisValue(input_args, &axis_value, &axis_shape, primitive);
93 auto reduce_to_end_ptr = primitive->GetAttr(kReduceToEnd);
94 bool reduce_to_end = reduce_to_end_ptr && GetValue<bool>(reduce_to_end_ptr);
95 if (reduce_to_end) {
96 if (axis_value.size() != 1) {
97 MS_EXCEPTION(ValueError) << "For '" << primitive->name()
98 << "', if 'reduce_to_end' is Bool, the axis num should 1";
99 }
100 int64_t begin_axis = axis_value[0];
101 for (int64_t i = begin_axis + 1; i < SizeToLong(x_shape.size()); ++i) {
102 axis_value.push_back(i);
103 }
104 axis_shape = SizeToLong(x_shape.size()) - begin_axis;
105 keep_dims = false;
106 }
107
108 ShapeVector out_shape = {};
109 constexpr int dynamic_rank_value = -2;
110 if (IsDynamicRank(x_shape)) {
111 if (axis_shape == 0 && !keep_dims) {
112 return std::make_shared<abstract::Shape>(out_shape);
113 }
114 out_shape.push_back(dynamic_rank_value);
115 return std::make_shared<abstract::Shape>(out_shape);
116 }
117 if (axis_shape == -1 && !keep_dims) {
118 out_shape.push_back(dynamic_rank_value);
119 return std::make_shared<abstract::Shape>(out_shape);
120 }
121 ReduceFuncCheckAxisInferImpl(primitive, &axis_value, x_shape.size());
122
123 if (axis_is_dynamic) {
124 out_shape = ReduceFuncCalShapeAxisDyn(x_shape, keep_dims);
125 return std::make_shared<abstract::Shape>(out_shape);
126 }
127 out_shape = ReduceFuncCalShapeInferImpl(primitive, x_shape, axis_value, keep_dims);
128 return std::make_shared<abstract::Shape>(out_shape);
129 }
130 } // namespace
131
132 class ReduceFusionInfer : public abstract::OpInferBase {
133 public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const134 BaseShapePtr InferShape(const PrimitivePtr &primitive,
135 const std::vector<AbstractBasePtr> &input_args) const override {
136 const int64_t input_num = 1;
137 MS_EXCEPTION_IF_NULL(primitive);
138 CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num,
139 primitive->name());
140 return ReduceFusionInferShape(primitive, input_args);
141 }
142
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const143 TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
144 MS_EXCEPTION_IF_NULL(input_args[0]);
145 auto x_type = input_args[0]->GetType();
146 return x_type;
147 }
148 };
149
// Registers ReduceFusionInfer as the inference implementation for prim::kPrimReduceFusion.
// NOTE(review): the meaning of the trailing `false` argument is defined by REGISTER_PRIMITIVE_OP_INFER_IMPL in
// abstract/ops/primitive_infer_map.h — confirm there before changing it.
REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceFusion, prim::kPrimReduceFusion, ReduceFusionInfer, false);
151 } // namespace ops
152 } // namespace mindspore
153