# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

# Initial optimizer state. Shared by the Net parameters and by the NumPy
# reference computation in the test, so both start from identical values.
var_np = np.random.rand(3, 3).astype(np.float32)
accum_np = np.random.rand(3, 3).astype(np.float32)


class Net(nn.Cell):
    """Wrap a single P.ApplyAdagrad op around `var` / `accum` parameters.

    The parameters are initialised from the module-level `var_np` and
    `accum_np` arrays so the expected update can be computed in NumPy.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.apply_adagrad = P.ApplyAdagrad()
        self.var = Parameter(Tensor(var_np), name="var")
        self.accum = Parameter(Tensor(accum_np), name="accum")

    def construct(self, lr, grad):
        """Apply one Adagrad step with learning rate `lr` and gradient `grad`."""
        return self.apply_adagrad(self.var, self.accum, lr, grad)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_apply_adagrad():
    """Check one ApplyAdagrad step on CPU against a NumPy reference.

    Reference update computed below:
        accum <- accum + grad * grad
        var   <- var - lr * grad / sqrt(accum + 1e-6)
    """
    lr_value = 0.001  # single source of truth for the reference and the op
    tolerance = 1e-6  # absolute tolerance for the comparisons

    # NumPy reference.
    grad_np = np.random.rand(3, 3).astype(np.float32)
    expect_accum_np = accum_np + grad_np * grad_np
    # The 1e-6 inside the sqrt is the epsilon this reference has always used.
    # NOTE(review): assumed to mirror the CPU kernel's epsilon -- confirm.
    expect_var_np = var_np - (lr_value * grad_np * (1 / np.sqrt(expect_accum_np + 1e-6)))

    net = Net()
    lr = Tensor(lr_value, mstype.float32)
    grad = Tensor(grad_np)
    out = net(lr, grad)
    # out[0] / out[1] are compared against the updated var / accum respectively.
    res_var_mindspore = out[0].asnumpy()
    res_accum_mindspore = out[1].asnumpy()

    # Compare absolute differences so that a result overshooting the
    # expectation fails too; a one-sided `expect - res < eps` check would
    # accept any result that is larger than expected.
    np.testing.assert_allclose(res_var_mindspore, expect_var_np, rtol=0, atol=tolerance)
    np.testing.assert_allclose(res_accum_mindspore, expect_accum_np, rtol=0, atol=tolerance)