• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <string>
18 #include <algorithm>
19 #include <memory>
20 #include <set>
21 #include <vector>
22 #include "ops/fusion/reduce_fusion.h"
23 #include "ops/op_utils.h"
24 #include "utils/check_convert_utils.h"
25 #include "abstract/primitive_infer_map.h"
26 
27 namespace mindspore {
28 namespace ops {
set_keep_dims(const bool keep_dims)29 void ReduceFusion::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKeepDims, MakeValue(keep_dims)); }
30 
set_mode(const ReduceMode mode)31 void ReduceFusion::set_mode(const ReduceMode mode) {
32   int64_t swi = mode;
33   (void)this->AddAttr(kMode, MakeValue(swi));
34 }
35 
set_reduce_to_end(const bool reduce_to_end)36 void ReduceFusion::set_reduce_to_end(const bool reduce_to_end) {
37   (void)this->AddAttr(kReduceToEnd, MakeValue(reduce_to_end));
38 }
39 
set_coeff(const float coeff)40 void ReduceFusion::set_coeff(const float coeff) { (void)this->AddAttr(kCoeff, MakeValue(coeff)); }
41 
get_keep_dims() const42 bool ReduceFusion::get_keep_dims() const {
43   auto value_ptr = GetAttr(kKeepDims);
44   MS_EXCEPTION_IF_NULL(value_ptr);
45   return GetValue<bool>(value_ptr);
46 }
47 
get_mode() const48 ReduceMode ReduceFusion::get_mode() const {
49   auto value_ptr = GetAttr(kMode);
50   MS_EXCEPTION_IF_NULL(value_ptr);
51   return ReduceMode(GetValue<int64_t>(value_ptr));
52 }
53 
get_reduce_to_end() const54 bool ReduceFusion::get_reduce_to_end() const {
55   auto value_ptr = GetAttr(kReduceToEnd);
56   MS_EXCEPTION_IF_NULL(value_ptr);
57   return GetValue<bool>(value_ptr);
58 }
59 
get_coeff() const60 float ReduceFusion::get_coeff() const {
61   auto value_ptr = GetAttr(kCoeff);
62   MS_EXCEPTION_IF_NULL(value_ptr);
63   return GetValue<float>(value_ptr);
64 }
65 
Init(const bool keep_dims,const ReduceMode mode,const bool reduce_to_end,const float coeff)66 void ReduceFusion::Init(const bool keep_dims, const ReduceMode mode, const bool reduce_to_end, const float coeff) {
67   this->set_keep_dims(keep_dims);
68   this->set_mode(mode);
69   this->set_reduce_to_end(reduce_to_end);
70   this->set_coeff(coeff);
71 }
72 REGISTER_PRIMITIVE_C(kNameReduceFusion, ReduceFusion);
73 }  // namespace ops
74 }  // namespace mindspore
75