• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <vector>
18 #include <memory>
19 #include <set>
20 #include <map>
21 #include <string>
22 #include "ops/apply_momentum.h"
23 #include "ops/op_utils.h"
24 #include "utils/check_convert_utils.h"
25 
26 namespace mindspore {
27 namespace ops {
Init(const bool use_nesterov,const bool use_locking,const float gradient_scale)28 void ApplyMomentum::Init(const bool use_nesterov, const bool use_locking, const float gradient_scale) {
29   this->set_use_nesterov(use_nesterov);
30   this->set_use_locking(use_locking);
31   this->set_gradient_scale(gradient_scale);
32 }
33 
set_use_nesterov(const bool use_nesterov)34 void ApplyMomentum::set_use_nesterov(const bool use_nesterov) {
35   (void)this->AddAttr(kUseNesterov, MakeValue(use_nesterov));
36 }
37 
set_use_locking(const bool use_locking)38 void ApplyMomentum::set_use_locking(const bool use_locking) {
39   (void)this->AddAttr(kUseLocking, MakeValue(use_locking));
40 }
41 
set_gradient_scale(const float gradient_scale)42 void ApplyMomentum::set_gradient_scale(const float gradient_scale) {
43   (void)this->AddAttr(kGradientScale, MakeValue(gradient_scale));
44 }
45 
get_use_nesterov() const46 bool ApplyMomentum::get_use_nesterov() const {
47   auto value_ptr = GetAttr(kUseNesterov);
48   return GetValue<bool>(value_ptr);
49 }
50 
get_use_locking() const51 bool ApplyMomentum::get_use_locking() const {
52   auto value_ptr = GetAttr(kUseLocking);
53   return GetValue<bool>(value_ptr);
54 }
55 
get_gradient_scale() const56 float ApplyMomentum::get_gradient_scale() const {
57   auto value_ptr = GetAttr(kGradientScale);
58   return GetValue<float>(value_ptr);
59 }
ApplyMomentumInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)60 AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
61                                    const std::vector<AbstractBasePtr> &input_args) {
62   MS_EXCEPTION_IF_NULL(primitive);
63   auto prim_name = primitive->name();
64   const int64_t input_num = 5;
65   (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
66 
67   for (const auto &item : input_args) {
68     MS_EXCEPTION_IF_NULL(item);
69   }
70   // Infer shape
71   auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
72 
73   // Infer type
74   auto v_tensor_type = input_args[kInputIndex0]->BuildType();
75   auto a_tensor_type = input_args[kInputIndex1]->BuildType();
76   auto l_type = input_args[kInputIndex2]->BuildType();
77   auto g_type = input_args[kInputIndex3]->BuildType();
78   auto m_type = input_args[kInputIndex4]->BuildType();
79   const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
80   (void)CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name);
81   (void)CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name);
82   std::map<std::string, TypePtr> args;
83   (void)args.insert(std::make_pair("l_type", l_type));
84   (void)args.insert(std::make_pair("g_type", g_type));
85   (void)args.insert(std::make_pair("m_type", m_type));
86   CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types, prim_name);
87   auto g_type_tensor = g_type->cast<TensorTypePtr>();
88   auto element = g_type_tensor->element();
89   return std::make_shared<abstract::AbstractTensor>(element, v_shape);
90 }
91 REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum);
92 }  // namespace ops
93 }  // namespace mindspore
94