• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import numpy as np
17import pytest
18
19import mindspore.context as context
20import mindspore.nn as nn
21from mindspore import Tensor, Parameter
22from mindspore.ops import operations as P
23
24
class AssignAdd(nn.Cell):
    """Cell that applies ``P.AssignAdd`` to a single Parameter named ``var``.

    Args:
        value: Initial data for the ``var`` Parameter (in this file a
            float32 ``Tensor`` is passed).
    """

    def __init__(self, value):
        super().__init__()
        self.var = Parameter(value, name="var")
        self.add = P.AssignAdd()

    def construct(self, y):
        # Add `y` into the `var` Parameter and hand back whatever the
        # AssignAdd primitive returns (the test compares it to var + y).
        return self.add(self.var, y)
34
35
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_assign_add():
    """Check ``P.AssignAdd`` on CPU in graph mode, applied in two rounds.

    Round 1 builds ``var`` from ``arange(27)`` reshaped to (1, 3, 3, 3) and
    adds the same values, so the expected result is ``2 * arange(27)``.
    Round 2 builds a new ``AssignAdd`` whose ``var`` comes from round 1's
    output and adds ``arange(27)`` again, giving ``3 * arange(27)``.
    All values are small integers, so float32 comparison is exact.
    """
    expect1 = np.array([[[[0, 2, 4],
                          [6, 8, 10],
                          [12, 14, 16]],
                         [[18, 20, 22],
                          [24, 26, 28],
                          [30, 32, 34]],
                         [[36, 38, 40],
                          [42, 44, 46],
                          [48, 50, 52]]]], dtype=np.float32)
    expect2 = np.array([[[[0, 3, 6],
                          [9, 12, 15],
                          [18, 21, 24]],
                         [[27, 30, 33],
                          [36, 39, 42],
                          [45, 48, 51]],
                         [[54, 57, 60],
                          [63, 66, 69],
                          [72, 75, 78]]]], dtype=np.float32)

    x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
    y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))

    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    add = AssignAdd(x2)
    output1 = add(y2)
    # assert_array_equal also verifies the shape; `(a == b).all()` would
    # silently pass for broadcast-compatible shapes such as (3, 3, 3).
    np.testing.assert_array_equal(output1.asnumpy(), expect1)
    add = AssignAdd(output1)
    output2 = add(y2)
    np.testing.assert_array_equal(output2.asnumpy(), expect2)
69