• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "ops/adam.h"
18 #include "ops/op_utils.h"
19 #include "utils/check_convert_utils.h"
20 
21 namespace mindspore {
22 namespace ops {
23 namespace {
AdamInfer(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)24 abstract::AbstractBasePtr AdamInfer(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
25   MS_EXCEPTION_IF_NULL(primitive);
26   auto prim_name = primitive->name();
27   const int64_t input_num = 10;
28   CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, prim_name);
29 
30   // infer shape
31   auto var_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape];
32   auto m_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->GetShapeTrack())[kShape];
33   auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->GetShapeTrack())[kShape];
34   auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex9]->GetShapeTrack())[kShape];
35   CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "m_shape", m_shape, prim_name);
36   CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "v_shape", v_shape, prim_name);
37   CheckAndConvertUtils::Check("var_shape", var_shape, kEqual, "grad_shape", grad_shape, prim_name);
38 
39   // infer type
40   auto var_type = input_args[kInputIndex0]->BuildType();
41   auto m_type = input_args[kInputIndex1]->BuildType();
42   auto v_type = input_args[kInputIndex2]->BuildType();
43   auto grad_type = input_args[kInputIndex9]->BuildType();
44   auto infer_var_type = CheckAndConvertUtils::CheckTensorTypeValid("var_type", var_type, common_valid_types, prim_name);
45   auto infer_m_type = CheckAndConvertUtils::CheckTensorTypeValid("m_type", m_type, common_valid_types, prim_name);
46   auto infer_v_type = CheckAndConvertUtils::CheckTensorTypeValid("v_type", v_type, common_valid_types, prim_name);
47   (void)CheckAndConvertUtils::CheckTensorTypeValid("grad_type", grad_type, common_valid_types, prim_name);
48   auto output0 = std::make_shared<abstract::AbstractTensor>(infer_var_type, var_shape);
49   auto output1 = std::make_shared<abstract::AbstractTensor>(infer_m_type, m_shape);
50   auto output2 = std::make_shared<abstract::AbstractTensor>(infer_v_type, v_shape);
51   AbstractBasePtrList output = {output0, output1, output2};
52   return std::make_shared<abstract::AbstractTuple>(output);
53 }
54 }  // namespace
Init(const bool use_locking,const bool use_nesterov)55 void Adam::Init(const bool use_locking, const bool use_nesterov) {
56   this->set_use_locking(use_locking);
57   this->set_use_nesterov(use_nesterov);
58 }
59 
set_use_locking(const bool use_locking)60 void Adam::set_use_locking(const bool use_locking) { (void)this->AddAttr(kUseLocking, MakeValue(use_locking)); }
61 
set_use_nesterov(const bool use_nesterov)62 void Adam::set_use_nesterov(const bool use_nesterov) { (void)this->AddAttr(kUseNesterov, MakeValue(use_nesterov)); }
63 
get_use_locking() const64 bool Adam::get_use_locking() const {
65   auto value_ptr = GetAttr(kUseLocking);
66   return GetValue<bool>(value_ptr);
67 }
68 
get_use_nesterov() const69 bool Adam::get_use_nesterov() const {
70   auto value_ptr = GetAttr(kUseNesterov);
71   return GetValue<bool>(value_ptr);
72 }
73 
AdamInfer(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)74 AbstractBasePtr AdamInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
75                           const std::vector<AbstractBasePtr> &input_args) {
76   return std::make_shared<abstract::AbstractTensor>(AdamInfer(primitive, input_args));
77 }
78 REGISTER_PRIMITIVE_C(kNameAdam, Adam);
79 }  // namespace ops
80 }  // namespace mindspore
81