1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "ops/batch_norm.h"
#include "abstract/primitive_infer_map.h"
#include "utils/check_convert_utils.h"
24
25 namespace mindspore {
26 namespace ops {
Init(const bool is_training,const float epsilon,const float momentum,const Format & format)27 void BatchNorm::Init(const bool is_training, const float epsilon, const float momentum, const Format &format) {
28 set_is_training(is_training);
29 set_epsilon(epsilon);
30 set_format(format);
31 set_momentum(momentum);
32 }
33
set_is_training(const bool is_training)34 void BatchNorm::set_is_training(const bool is_training) { (void)this->AddAttr(kIsTraining, MakeValue(is_training)); }
35
set_epsilon(const float epsilon)36 void BatchNorm::set_epsilon(const float epsilon) {
37 CheckAndConvertUtils::CheckInRange<float>(kEpsilon, epsilon, kIncludeBoth, {0.0, 1.0}, this->name());
38 (void)this->AddAttr(kEpsilon, MakeValue(epsilon));
39 }
40
set_format(const Format & format)41 void BatchNorm::set_format(const Format &format) {
42 int64_t f = format;
43 (void)this->AddAttr(kFormat, MakeValue(f));
44 }
45
set_momentum(const float momentun)46 void BatchNorm::set_momentum(const float momentun) {
47 CheckAndConvertUtils::CheckInRange<float>(kMomentum, momentun, kIncludeBoth, {0.0, 1.0}, this->name());
48 (void)this->AddAttr(kMomentum, MakeValue(momentun));
49 }
50
get_momentum() const51 float BatchNorm::get_momentum() const {
52 auto value_ptr = GetAttr(kMomentum);
53 return GetValue<float>(value_ptr);
54 }
55
get_is_training() const56 bool BatchNorm::get_is_training() const {
57 auto value_ptr = GetAttr(kIsTraining);
58 return GetValue<bool>(value_ptr);
59 }
60
get_epsilon() const61 float BatchNorm::get_epsilon() const {
62 auto value_ptr = GetAttr(kEpsilon);
63 return GetValue<float>(value_ptr);
64 }
65
get_format() const66 Format BatchNorm::get_format() const {
67 auto value_ptr = GetAttr(kFormat);
68 return Format(GetValue<int64_t>(value_ptr));
69 }
70
BatchNormInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)71 AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
72 const std::vector<AbstractBasePtr> &input_args) {
73 // Infer shape
74 MS_EXCEPTION_IF_NULL(primitive);
75 auto prim_name = primitive->name();
76 const int64_t input_num = 5;
77 (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
78
79 auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
80 auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
81 if (format == NHWC) {
82 input_x = {input_x[0], input_x[3], input_x[1], input_x[2]};
83 }
84 auto scale = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
85 auto bias = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
86 auto mean = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
87 auto variance = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];
88
89 std::vector<int64_t> input_shape_norm;
90 if (format == NCHW) {
91 input_shape_norm = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
92 } else {
93 input_shape_norm.push_back(input_x[0]);
94 input_shape_norm.push_back(input_x[3]);
95 input_shape_norm.push_back(input_x[1]);
96 input_shape_norm.push_back(input_x[2]);
97 }
98 (void)CheckAndConvertUtils::CheckInteger("scale rank", SizeToLong(scale.size()), kEqual, 1, prim_name);
99 CheckAndConvertUtils::Check("scale shape", scale, kEqual, "bias shape", bias, prim_name, TypeError);
100 CheckAndConvertUtils::Check("scale shape[0]", scale[0], kEqual, "input_x channel", input_shape_norm[1], prim_name,
101 TypeError);
102
103 if (!GetValue<bool>(primitive->GetAttr(kIsTraining))) {
104 (void)CheckAndConvertUtils::CheckInteger("mean rank", SizeToLong(mean.size()), kEqual, 1, prim_name);
105 CheckAndConvertUtils::Check("mean shape", mean, kEqual, "variance shape", variance, prim_name, TypeError);
106 CheckAndConvertUtils::Check("mean shape", mean, kEqual, "scale shape", scale, prim_name, TypeError);
107 }
108
109 // Infer type
110 auto scale_type = input_args[kInputIndex1]->BuildType()->cast<TensorTypePtr>()->element();
111 auto bias_type = input_args[kInputIndex2]->BuildType()->cast<TensorTypePtr>()->element();
112
113 const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
114 auto input_x_type =
115 CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[kInputIndex0]->BuildType(), valid_types, prim_name);
116 std::map<std::string, TypePtr> args;
117 (void)args.emplace("scale", input_args[kInputIndex1]->BuildType());
118 (void)args.emplace("bias", input_args[kInputIndex2]->BuildType());
119 (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
120 std::map<std::string, TypePtr> args_moving;
121 (void)args_moving.emplace("scale", input_args[kInputIndex2]->BuildType());
122 (void)args_moving.emplace("bias", input_args[kInputIndex3]->BuildType());
123 (void)CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name);
124
125 auto output0 = std::make_shared<abstract::AbstractTensor>(input_x_type, input_x);
126 auto output1 = std::make_shared<abstract::AbstractTensor>(scale_type, scale);
127 auto output2 = std::make_shared<abstract::AbstractTensor>(bias_type, scale);
128 auto output3 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale);
129 if (format == NHWC) {
130 output2 = std::make_shared<abstract::AbstractTensor>(scale_type, scale);
131 output3 = std::make_shared<abstract::AbstractTensor>(bias_type, scale);
132 output1 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale);
133 }
134 AbstractBasePtrList output = {output0, output1, output2, output3, output3};
135 return std::make_shared<abstract::AbstractTuple>(output);
136 }
// Associates the BatchNorm primitive class with the name kNameBatchNorm.
// NOTE(review): presumably this registers a factory so the op can be created by name -- confirm
// against the REGISTER_PRIMITIVE_C macro definition.
REGISTER_PRIMITIVE_C(kNameBatchNorm, BatchNorm);
138 } // namespace ops
139 } // namespace mindspore
140