1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <vector>
18 #include <memory>
19 #include <set>
20 #include <map>
21 #include <string>
22 #include "ops/apply_momentum.h"
23 #include "ops/op_utils.h"
24 #include "utils/check_convert_utils.h"
25
26 namespace mindspore {
27 namespace ops {
Init(const bool use_nesterov,const bool use_locking,const float gradient_scale)28 void ApplyMomentum::Init(const bool use_nesterov, const bool use_locking, const float gradient_scale) {
29 this->set_use_nesterov(use_nesterov);
30 this->set_use_locking(use_locking);
31 this->set_gradient_scale(gradient_scale);
32 }
33
set_use_nesterov(const bool use_nesterov)34 void ApplyMomentum::set_use_nesterov(const bool use_nesterov) {
35 (void)this->AddAttr(kUseNesterov, MakeValue(use_nesterov));
36 }
37
set_use_locking(const bool use_locking)38 void ApplyMomentum::set_use_locking(const bool use_locking) {
39 (void)this->AddAttr(kUseLocking, MakeValue(use_locking));
40 }
41
set_gradient_scale(const float gradient_scale)42 void ApplyMomentum::set_gradient_scale(const float gradient_scale) {
43 (void)this->AddAttr(kGradientScale, MakeValue(gradient_scale));
44 }
45
get_use_nesterov() const46 bool ApplyMomentum::get_use_nesterov() const {
47 auto value_ptr = GetAttr(kUseNesterov);
48 return GetValue<bool>(value_ptr);
49 }
50
get_use_locking() const51 bool ApplyMomentum::get_use_locking() const {
52 auto value_ptr = GetAttr(kUseLocking);
53 return GetValue<bool>(value_ptr);
54 }
55
get_gradient_scale() const56 float ApplyMomentum::get_gradient_scale() const {
57 auto value_ptr = GetAttr(kGradientScale);
58 return GetValue<float>(value_ptr);
59 }
ApplyMomentumInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)60 AbstractBasePtr ApplyMomentumInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
61 const std::vector<AbstractBasePtr> &input_args) {
62 MS_EXCEPTION_IF_NULL(primitive);
63 auto prim_name = primitive->name();
64 const int64_t input_num = 5;
65 (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
66
67 for (const auto &item : input_args) {
68 MS_EXCEPTION_IF_NULL(item);
69 }
70 // Infer shape
71 auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
72
73 // Infer type
74 auto v_tensor_type = input_args[kInputIndex0]->BuildType();
75 auto a_tensor_type = input_args[kInputIndex1]->BuildType();
76 auto l_type = input_args[kInputIndex2]->BuildType();
77 auto g_type = input_args[kInputIndex3]->BuildType();
78 auto m_type = input_args[kInputIndex4]->BuildType();
79 const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
80 (void)CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_tensor_type, valid_types, prim_name);
81 (void)CheckAndConvertUtils::CheckTensorTypeValid("a_type", a_tensor_type, valid_types, prim_name);
82 std::map<std::string, TypePtr> args;
83 (void)args.insert(std::make_pair("l_type", l_type));
84 (void)args.insert(std::make_pair("g_type", g_type));
85 (void)args.insert(std::make_pair("m_type", m_type));
86 CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args, valid_types, prim_name);
87 auto g_type_tensor = g_type->cast<TensorTypePtr>();
88 auto element = g_type_tensor->element();
89 return std::make_shared<abstract::AbstractTensor>(element, v_shape);
90 }
// Register ApplyMomentum with the primitive factory so it can be constructed by name.
REGISTER_PRIMITIVE_C(kNameApplyMomentum, ApplyMomentum);
92 } // namespace ops
93 } // namespace mindspore
94