1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ops/adam.h"
18 #include "ops/op_utils.h"
19 #include "utils/check_convert_utils.h"
20
21 namespace mindspore {
22 namespace ops {
23 namespace {
AdamInfer(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)24 abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
25 MS_EXCEPTION_IF_NULL(primitive);
26 auto prim_name = primitive->name();
27 const int64_t input_num = 10;
28 CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim_name);
29
30 // infer shape
31 auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape];
32 auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
33 auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
34 auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex9]->GetShapeTrack())[kShape];
35 CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name);
36 CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name);
37 CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name);
38
39 // infer type
40 auto var_type = input_args[kInputIndex0]->BuildType();
41 auto m_type = input_args[kInputIndex1]->BuildType();
42 auto v_type = input_args[kInputIndex2]->BuildType();
43 auto grad_type = input_args[kInputIndex9]->BuildType();
44 auto infer_var_type = CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name);
45 auto infer_m_type = CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name);
46 auto infer_v_type = CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name);
47 (void)CheckAndConvertUtils::CheckTensorTypeValid("grad_type", grad_type, common_valid_types, prim_name);
48 auto output0 = std::make_shared<abstract::AbstractTensor>(infer_var_type, var_shape);
49 auto output1 = std::make_shared<abstract::AbstractTensor>(infer_m_type, m_shape);
50 auto output2 = std::make_shared<abstract::AbstractTensor>(infer_v_type, v_shape);
51 AbstractBasePtrList output = {output0, output1, output2};
52 return std::make_shared<abstract::AbstractTuple>(output);
53 }
54 } // namespace
Init(const bool use_locking,const bool use_nesterov)55 void Adam::Init(const bool use_locking, const bool use_nesterov) {
56 this->set_use_locking(use_locking);
57 this->set_use_nesterov(use_nesterov);
58 }
59
set_use_locking(const bool use_locking)60 void Adam::set_use_locking(const bool use_locking) { (void)this->AddAttr(kUseLocking, MakeValue(use_locking)); }
61
set_use_nesterov(const bool use_nesterov)62 void Adam::set_use_nesterov(const bool use_nesterov) { (void)this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); }
63
get_use_locking() const64 bool Adam::get_use_locking() const {
65 auto value_ptr = GetAttr(kUseLocking);
66 return GetValue<bool>(value_ptr);
67 }
68
get_use_nesterov() const69 bool Adam::get_use_nesterov() const {
70 auto value_ptr = GetAttr(kUseNesterov);
71 return GetValue<bool>(value_ptr);
72 }
73
AdamInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)74 AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
75 const std::vector<AbstractBasePtr> &input_args) {
76 return std::make_shared<abstract::AbstractTensor>(AdamInfer(primitive, input_args));
77 }
78 REGISTER_PRIMITIVE_C(kNameAdam, Adam);
79 } // namespace ops
80 } // namespace mindspore
81