• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <string>
17 #include <algorithm>
18 #include <memory>
19 #include <set>
20 #include <vector>
21 #include "ops/batch_norm.h"
22 #include "abstract/primitive_infer_map.h"
23 #include "utils/check_convert_utils.h"
24 
25 namespace mindspore {
26 namespace ops {
Init(const bool is_training,const float epsilon,const float momentum,const Format & format)27 void BatchNorm::Init(const bool is_training, const float epsilon, const float momentum, const Format &format) {
28   set_is_training(is_training);
29   set_epsilon(epsilon);
30   set_format(format);
31   set_momentum(momentum);
32 }
33 
set_is_training(const bool is_training)34 void BatchNorm::set_is_training(const bool is_training) { (void)this->AddAttr(kIsTraining, MakeValue(is_training)); }
35 
set_epsilon(const float epsilon)36 void BatchNorm::set_epsilon(const float epsilon) {
37   CheckAndConvertUtils::CheckInRange<float>(kEpsilon, epsilon, kIncludeBoth, {0.0, 1.0}, this->name());
38   (void)this->AddAttr(kEpsilon, MakeValue(epsilon));
39 }
40 
set_format(const Format & format)41 void BatchNorm::set_format(const Format &format) {
42   int64_t f = format;
43   (void)this->AddAttr(kFormat, MakeValue(f));
44 }
45 
set_momentum(const float momentun)46 void BatchNorm::set_momentum(const float momentun) {
47   CheckAndConvertUtils::CheckInRange<float>(kMomentum, momentun, kIncludeBoth, {0.0, 1.0}, this->name());
48   (void)this->AddAttr(kMomentum, MakeValue(momentun));
49 }
50 
get_momentum() const51 float BatchNorm::get_momentum() const {
52   auto value_ptr = GetAttr(kMomentum);
53   return GetValue<float>(value_ptr);
54 }
55 
get_is_training() const56 bool BatchNorm::get_is_training() const {
57   auto value_ptr = GetAttr(kIsTraining);
58   return GetValue<bool>(value_ptr);
59 }
60 
get_epsilon() const61 float BatchNorm::get_epsilon() const {
62   auto value_ptr = GetAttr(kEpsilon);
63   return GetValue<float>(value_ptr);
64 }
65 
get_format() const66 Format BatchNorm::get_format() const {
67   auto value_ptr = GetAttr(kFormat);
68   return Format(GetValue<int64_t>(value_ptr));
69 }
70 
// Shape/type inference for BatchNorm.
// Inputs (5): [0] x, [1] scale, [2] bias, [3] mean, [4] variance.
// Returns a 5-element tuple of AbstractTensors.
AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
  // Infer shape
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  const int64_t input_num = 5;
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);

  auto input_x = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  auto format = Format(GetValue<int64_t>(primitive->GetAttr(kFormat)));
  // Normalize the x shape to NCHW order so the channel dim is always index 1 below.
  // NOTE(review): this indexing assumes a rank-4 input; no explicit rank check precedes it.
  if (format == NHWC) {
    input_x = {input_x[0], input_x[3], input_x[1], input_x[2]};
  }
  auto scale = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
  auto bias = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
  auto mean = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex3]->BuildShape())[kShape];
  auto variance = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex4]->BuildShape())[kShape];

  // Build the NCHW-ordered shape used for the channel-count check below.
  // For NCHW it is read from the shape track; for NHWC it is the already-permuted input_x.
  std::vector<int64_t> input_shape_norm;
  if (format == NCHW) {
    input_shape_norm = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
  } else {
    input_shape_norm.push_back(input_x[0]);
    input_shape_norm.push_back(input_x[3]);
    input_shape_norm.push_back(input_x[1]);
    input_shape_norm.push_back(input_x[2]);
  }
  // scale must be 1-D, match bias, and its length must equal the channel count.
  (void)CheckAndConvertUtils::CheckInteger("scale rank", SizeToLong(scale.size()), kEqual, 1, prim_name);
  CheckAndConvertUtils::Check("scale shape", scale, kEqual, "bias shape", bias, prim_name, TypeError);
  CheckAndConvertUtils::Check("scale shape[0]", scale[0], kEqual, "input_x channel", input_shape_norm[1], prim_name,
                              TypeError);

  // In inference mode the running mean/variance must be 1-D and agree with scale.
  if (!GetValue<bool>(primitive->GetAttr(kIsTraining))) {
    (void)CheckAndConvertUtils::CheckInteger("mean rank", SizeToLong(mean.size()), kEqual, 1, prim_name);
    CheckAndConvertUtils::Check("mean shape", mean, kEqual, "variance shape", variance, prim_name, TypeError);
    CheckAndConvertUtils::Check("mean shape", mean, kEqual, "scale shape", scale, prim_name, TypeError);
  }

  // Infer type
  auto scale_type = input_args[kInputIndex1]->BuildType()->cast<TensorTypePtr>()->element();
  auto bias_type = input_args[kInputIndex2]->BuildType()->cast<TensorTypePtr>()->element();

  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  auto input_x_type =
    CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[kInputIndex0]->BuildType(), valid_types, prim_name);
  std::map<std::string, TypePtr> args;
  (void)args.emplace("scale", input_args[kInputIndex1]->BuildType());
  (void)args.emplace("bias", input_args[kInputIndex2]->BuildType());
  (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
  // NOTE(review): args_moving is keyed "scale"/"bias" but reads inputs 2 and 3,
  // which are bias and mean — the labels (and possibly the indices 3/4 for
  // mean/variance) look unintended; confirm against the operator spec.
  std::map<std::string, TypePtr> args_moving;
  (void)args_moving.emplace("scale", input_args[kInputIndex2]->BuildType());
  (void)args_moving.emplace("bias", input_args[kInputIndex3]->BuildType());
  (void)CheckAndConvertUtils::CheckTensorTypeSame(args_moving, valid_types, prim_name);

  // All auxiliary outputs share the 1-D scale shape; only the element types differ.
  auto output0 = std::make_shared<abstract::AbstractTensor>(input_x_type, input_x);
  auto output1 = std::make_shared<abstract::AbstractTensor>(scale_type, scale);
  auto output2 = std::make_shared<abstract::AbstractTensor>(bias_type, scale);
  auto output3 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale);
  // NOTE(review): the NHWC branch rotates the element types among outputs 1-3
  // (output1 gets input_x_type, etc.); presumably backend-specific — verify intent.
  if (format == NHWC) {
    output2 = std::make_shared<abstract::AbstractTensor>(scale_type, scale);
    output3 = std::make_shared<abstract::AbstractTensor>(bias_type, scale);
    output1 = std::make_shared<abstract::AbstractTensor>(input_x_type, scale);
  }
  // NOTE(review): output3 is deliberately listed twice to fill the 5-output
  // contract (reserve spaces) — confirm this matches the kernel's expectations.
  AbstractBasePtrList output = {output0, output1, output2, output3, output3};
  return std::make_shared<abstract::AbstractTuple>(output);
}
// Register this primitive's C++ implementation under its canonical name.
REGISTER_PRIMITIVE_C(kNameBatchNorm, BatchNorm);
138 }  // namespace ops
139 }  // namespace mindspore
140