1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_OPS_BATCH_NORMAL_H_ 18 #define MINDSPORE_CORE_OPS_BATCH_NORMAL_H_ 19 #include <map> 20 #include <vector> 21 #include <memory> 22 #include <string> 23 #include "ops/op_utils.h" 24 #include "ops/primitive_c.h" 25 #include "abstract/abstract_value.h" 26 27 namespace mindspore { 28 namespace ops { 29 constexpr auto kNameBatchNorm = "BatchNorm"; 30 /// \brief Batch Normalization for input data and updated parameters. 31 /// Refer to Python API @ref mindspore.ops.BatchNorm for more details. 32 class MS_CORE_API BatchNorm : public PrimitiveC { 33 public: 34 /// \brief Constructor. BatchNorm()35 BatchNorm() : PrimitiveC(kNameBatchNorm) { 36 InitIOName({"x", "scale", "offset", "mean", "variance"}, 37 {"y", "batch_mean", "batch_variance", "reserve_space_1", "reserve_space_2"}); 38 } 39 /// \brief Destructor. 40 ~BatchNorm() = default; 41 MS_DECLARE_PARENT(BatchNorm, PrimitiveC); 42 /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.BatchNorm for the inputs. 43 void Init(const bool is_training = false, const float epsilon = 1e-5, const float momentun = 0.1, 44 const Format &format = NCHW); 45 /// \brief Set is_training. 46 void set_is_training(const bool is_training); 47 /// \brief Set epsilon. 48 void set_epsilon(const float epsilon); 49 /// \brief Set format. 50 void set_format(const Format &format); 51 /// \brief Set momentum. 52 void set_momentum(const float momentum); 53 /// \brief Get is_training. 54 /// 55 /// \return is_training. 56 bool get_is_training() const; 57 /// \brief Get epsilon. 58 /// 59 /// \return epsilon. 60 float get_epsilon() const; 61 /// \brief Get format. 62 /// 63 /// \return format. 64 Format get_format() const; 65 /// \brief Get momentum. 66 /// 67 /// \return momentum. 68 float get_momentum() const; 69 }; 70 71 AbstractBasePtr BatchNormInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, 72 const std::vector<AbstractBasePtr> &input_args); 73 using PrimBatchNormPtr = std::shared_ptr<BatchNorm>; 74 75 } // namespace ops 76 } // namespace mindspore 77 78 #endif // MINDSPORE_CORE_OPS_BatchNorm_H_ 79