1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <string>
18 #include <algorithm>
19 #include <memory>
20 #include <set>
21 #include <vector>
22 #include "ops/fusion/reduce_fusion.h"
23 #include "ops/op_utils.h"
24 #include "utils/check_convert_utils.h"
25 #include "abstract/primitive_infer_map.h"
26
27 namespace mindspore {
28 namespace ops {
set_keep_dims(const bool keep_dims)29 void ReduceFusion::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKeepDims, MakeValue(keep_dims)); }
30
set_mode(const ReduceMode mode)31 void ReduceFusion::set_mode(const ReduceMode mode) {
32 int64_t swi = mode;
33 (void)this->AddAttr(kMode, MakeValue(swi));
34 }
35
set_reduce_to_end(const bool reduce_to_end)36 void ReduceFusion::set_reduce_to_end(const bool reduce_to_end) {
37 (void)this->AddAttr(kReduceToEnd, MakeValue(reduce_to_end));
38 }
39
set_coeff(const float coeff)40 void ReduceFusion::set_coeff(const float coeff) { (void)this->AddAttr(kCoeff, MakeValue(coeff)); }
41
get_keep_dims() const42 bool ReduceFusion::get_keep_dims() const {
43 auto value_ptr = GetAttr(kKeepDims);
44 MS_EXCEPTION_IF_NULL(value_ptr);
45 return GetValue<bool>(value_ptr);
46 }
47
get_mode() const48 ReduceMode ReduceFusion::get_mode() const {
49 auto value_ptr = GetAttr(kMode);
50 MS_EXCEPTION_IF_NULL(value_ptr);
51 return ReduceMode(GetValue<int64_t>(value_ptr));
52 }
53
get_reduce_to_end() const54 bool ReduceFusion::get_reduce_to_end() const {
55 auto value_ptr = GetAttr(kReduceToEnd);
56 MS_EXCEPTION_IF_NULL(value_ptr);
57 return GetValue<bool>(value_ptr);
58 }
59
get_coeff() const60 float ReduceFusion::get_coeff() const {
61 auto value_ptr = GetAttr(kCoeff);
62 MS_EXCEPTION_IF_NULL(value_ptr);
63 return GetValue<float>(value_ptr);
64 }
65
Init(const bool keep_dims,const ReduceMode mode,const bool reduce_to_end,const float coeff)66 void ReduceFusion::Init(const bool keep_dims, const ReduceMode mode, const bool reduce_to_end, const float coeff) {
67 this->set_keep_dims(keep_dims);
68 this->set_mode(mode);
69 this->set_reduce_to_end(reduce_to_end);
70 this->set_coeff(coeff);
71 }
// Registers ReduceFusion with the primitive registry under the name kNameReduceFusion.
// NOTE(review): presumably this lets the primitive be created by name at runtime — confirm
// against the REGISTER_PRIMITIVE_C macro definition.
REGISTER_PRIMITIVE_C(kNameReduceFusion, ReduceFusion);
73 } // namespace ops
74 } // namespace mindspore
75