• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ops/fusion/reduce_fusion.h"
18 #include "abstract/ops/primitive_infer_map.h"
19 #include "mindapi/base/shared_ptr.h"
20 #include "mindapi/ir/value.h"
21 #include "mindapi/src/helper.h"
22 #include "mindspore/core/ops/lite_ops.h"
23 #include "ops/op_name.h"
24 #include "ops/op_utils.h"
25 #include "ops/primitive_c.h"
26 #include "utils/check_convert_utils.h"
27 #include "utils/log_adapter.h"
28 
29 namespace mindspore {
30 namespace ops {
31 MIND_API_OPERATOR_IMPL(ReduceFusion, Reduce);
set_keep_dims(const bool keep_dims)32 void ReduceFusion::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKeepDims, api::MakeValue(keep_dims)); }
33 
set_mode(const ReduceMode mode)34 void ReduceFusion::set_mode(const ReduceMode mode) {
35   int64_t swi = mode;
36   (void)this->AddAttr(kMode, api::MakeValue(swi));
37 }
38 
set_reduce_to_end(const bool reduce_to_end)39 void ReduceFusion::set_reduce_to_end(const bool reduce_to_end) {
40   (void)this->AddAttr(kReduceToEnd, api::MakeValue(reduce_to_end));
41 }
42 
set_coeff(const float coeff)43 void ReduceFusion::set_coeff(const float coeff) { (void)this->AddAttr(kCoeff, api::MakeValue(coeff)); }
44 
get_keep_dims() const45 bool ReduceFusion::get_keep_dims() const {
46   auto value_ptr = GetAttr(kKeepDims);
47   MS_EXCEPTION_IF_NULL(value_ptr);
48   return GetValue<bool>(value_ptr);
49 }
50 
get_mode() const51 ReduceMode ReduceFusion::get_mode() const {
52   auto value_ptr = GetAttr(kMode);
53   MS_EXCEPTION_IF_NULL(value_ptr);
54   return ReduceMode(GetValue<int64_t>(value_ptr));
55 }
56 
get_reduce_to_end() const57 bool ReduceFusion::get_reduce_to_end() const {
58   auto value_ptr = GetAttr(kReduceToEnd);
59   MS_EXCEPTION_IF_NULL(value_ptr);
60   return GetValue<bool>(value_ptr);
61 }
62 
get_coeff() const63 float ReduceFusion::get_coeff() const {
64   auto value_ptr = GetAttr(kCoeff);
65   MS_EXCEPTION_IF_NULL(value_ptr);
66   return GetValue<float>(value_ptr);
67 }
68 
Init(const bool keep_dims,const ReduceMode mode,const bool reduce_to_end,const float coeff)69 void ReduceFusion::Init(const bool keep_dims, const ReduceMode mode, const bool reduce_to_end, const float coeff) {
70   this->set_keep_dims(keep_dims);
71   this->set_mode(mode);
72   this->set_reduce_to_end(reduce_to_end);
73   this->set_coeff(coeff);
74 }
75 
76 namespace {
ReduceFusionInferShape(const PrimitivePtr & primitive,const std::vector<abstract::AbstractBasePtr> & input_args)77 abstract::ShapePtr ReduceFusionInferShape(const PrimitivePtr &primitive,
78                                           const std::vector<abstract::AbstractBasePtr> &input_args) {
79   MS_EXCEPTION_IF_NULL(primitive);
80   CheckAndConvertUtils::CheckArgsType(primitive->name(), input_args, 0, kObjectTypeTensorType);
81   auto x_shape = input_args[0]->GetShape()->GetShapeVector();
82 
83   auto keep_dims_value_ptr = primitive->GetAttr(kKeepDims);
84   MS_EXCEPTION_IF_NULL(keep_dims_value_ptr);
85   if (!keep_dims_value_ptr->isa<BoolImm>()) {
86     MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', 'keep_dims' must be Bool.";
87   }
88   bool keep_dims = GetValue<bool>(keep_dims_value_ptr);
89 
90   std::vector<int64_t> axis_value;
91   int64_t axis_shape = 1;
92   bool axis_is_dynamic = CheckAndGetAxisValue(input_args, &axis_value, &axis_shape, primitive);
93   auto reduce_to_end_ptr = primitive->GetAttr(kReduceToEnd);
94   bool reduce_to_end = reduce_to_end_ptr && GetValue<bool>(reduce_to_end_ptr);
95   if (reduce_to_end) {
96     if (axis_value.size() != 1) {
97       MS_EXCEPTION(ValueError) << "For '" << primitive->name()
98                                << "', if 'reduce_to_end' is Bool, the axis num should 1";
99     }
100     int64_t begin_axis = axis_value[0];
101     for (int64_t i = begin_axis + 1; i < SizeToLong(x_shape.size()); ++i) {
102       axis_value.push_back(i);
103     }
104     axis_shape = SizeToLong(x_shape.size()) - begin_axis;
105     keep_dims = false;
106   }
107 
108   ShapeVector out_shape = {};
109   constexpr int dynamic_rank_value = -2;
110   if (IsDynamicRank(x_shape)) {
111     if (axis_shape == 0 && !keep_dims) {
112       return std::make_shared<abstract::Shape>(out_shape);
113     }
114     out_shape.push_back(dynamic_rank_value);
115     return std::make_shared<abstract::Shape>(out_shape);
116   }
117   if (axis_shape == -1 && !keep_dims) {
118     out_shape.push_back(dynamic_rank_value);
119     return std::make_shared<abstract::Shape>(out_shape);
120   }
121   ReduceFuncCheckAxisInferImpl(primitive, &axis_value, x_shape.size());
122 
123   if (axis_is_dynamic) {
124     out_shape = ReduceFuncCalShapeAxisDyn(x_shape, keep_dims);
125     return std::make_shared<abstract::Shape>(out_shape);
126   }
127   out_shape = ReduceFuncCalShapeInferImpl(primitive, x_shape, axis_value, keep_dims);
128   return std::make_shared<abstract::Shape>(out_shape);
129 }
130 }  // namespace
131 
132 class ReduceFusionInfer : public abstract::OpInferBase {
133  public:
InferShape(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const134   BaseShapePtr InferShape(const PrimitivePtr &primitive,
135                           const std::vector<AbstractBasePtr> &input_args) const override {
136     const int64_t input_num = 1;
137     MS_EXCEPTION_IF_NULL(primitive);
138     CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num,
139                                        primitive->name());
140     return ReduceFusionInferShape(primitive, input_args);
141   }
142 
InferType(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args) const143   TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
144     MS_EXCEPTION_IF_NULL(input_args[0]);
145     auto x_type = input_args[0]->GetType();
146     return x_type;
147   }
148 };
149 
150 REGISTER_PRIMITIVE_OP_INFER_IMPL(ReduceFusion, prim::kPrimReduceFusion, ReduceFusionInfer, false);
151 }  // namespace ops
152 }  // namespace mindspore
153