• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
"""test function jacrev in graph and pynative modes"""
16import numpy as np
17import pytest
18import mindspore.nn as nn
19import mindspore.context as context
20from mindspore import Tensor
21from mindspore import jit
22from mindspore.ops import jacrev
23
24
class SingleInputSingleOutputNet(nn.Cell):
    """Cell with one input and one output: the elementwise cube ``x ** 3``."""

    def construct(self, x):
        cubed = x ** 3
        return cubed
28
29
class SingleInputMultipleOutputsNet(nn.Cell):
    """Cell with one input and two outputs: ``(x ** 3, 2 * x)``."""

    def construct(self, x):
        cubed = x ** 3
        doubled = 2 * x
        return cubed, doubled
33
34
class MultipleInputsSingleOutputNet(nn.Cell):
    """Cell with three inputs and one output: the elementwise product ``x * y * z``."""

    def construct(self, x, y, z):
        product = x * y * z
        return product
38
39
class MultipleInputsMultipleOutputsNet(nn.Cell):
    """Cell with three inputs and two outputs.

    Returns ``(x ** 2 + y ** 2 + z ** 2, x * y * z)``.
    """

    def construct(self, x, y, z):
        sum_of_squares = x ** 2 + y ** 2 + z ** 2
        product = x * y * z
        return sum_of_squares, product
43
44
def function(x, y, z):
    """Return ``(x ** 2 + y ** 2 + z ** 2, x * y * z)``.

    ``jac_wrap_with_jit_function`` passes this to ``jacrev`` with ``has_aux=True``,
    so the second element is returned as auxiliary data rather than differentiated.
    """
    sum_of_squares = x ** 2 + y ** 2 + z ** 2
    product = x * y * z
    return sum_of_squares, product
47
48
def iteration_jac_function(x, y, z):
    """Return ``x ** 2 * y * z``.

    Not referenced by the tests in this module.
    """
    squared = x ** 2
    return squared * y * z
51
52
@jit
def jac_wrap_with_jit_function(x, y, z):
    """Apply ``jacrev`` (with ``has_aux=True``) to ``function`` inside a ``@jit`` function.

    Returns whatever ``jacrev(function, has_aux=True)(x, y, z)`` returns: the Jacobian of
    the first output of ``function`` together with its second output as auxiliary data.
    """
    return jacrev(function, has_aux=True)(x, y, z)
57
58
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_single_input_single_output_cell_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev with single input and single output net in graph and PyNative modes.
    Expectation: The Jacobian has the expected shape and matches the analytic derivative.
    """
    context.set_context(mode=mode)
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    # d(x ** 3)/dx = 3 * x ** 2 elementwise, so the (2, 2, 2, 2) Jacobian is non-zero
    # only where the output element index equals the input element index.
    expect_jac = np.array([[[[3, 0], [0, 0]], [[0, 12], [0, 0]]],
                           [[[0, 0], [27, 0]], [[0, 0], [0, 48]]]]).astype(np.float32)
    jac = jacrev(net)(x).asnumpy()
    # np.allclose broadcasts its operands, so pin the shape explicitly as well.
    assert jac.shape == expect_jac.shape
    assert np.allclose(jac, expect_jac)
76
77
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_single_input_multiple_outputs_cell_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev with single input and multiple outputs net in graph and PyNative modes.
    Expectation: One Jacobian per output, each with the expected shape and values.
    """
    context.set_context(mode=mode)
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputMultipleOutputsNet()
    # Jacobian of x ** 3: 3 * x ** 2 on the "diagonal" of the (2, 2, 2, 2) tensor.
    expect_jac_0 = np.array([[[[3, 0], [0, 0]], [[0, 12], [0, 0]]],
                             [[[0, 0], [27, 0]], [[0, 0], [0, 48]]]]).astype(np.float32)
    # Jacobian of 2 * x: the constant 2 on the same diagonal.
    expect_jac_1 = np.array([[[[2, 0], [0, 0]], [[0, 2], [0, 0]]],
                             [[[0, 0], [2, 0]], [[0, 0], [0, 2]]]]).astype(np.float32)
    jac = jacrev(net)(x)
    # The net has two outputs, so jacrev must return exactly one Jacobian per output.
    assert isinstance(jac, tuple)
    assert len(jac) == 2
    jac_0 = jac[0].asnumpy()
    jac_1 = jac[1].asnumpy()
    # np.allclose broadcasts its operands, so check shapes explicitly as well.
    assert jac_0.shape == expect_jac_0.shape
    assert jac_1.shape == expect_jac_1.shape
    assert np.allclose(jac_0, expect_jac_0)
    assert np.allclose(jac_1, expect_jac_1)
98
99
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_multiple_inputs_single_output_cell_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev on a net with three inputs and one output, differentiating
        w.r.t. grad_position=(1, 2) (y and z), in graph and PyNative modes.
    Expectation: One Jacobian per selected input, each matching the analytic derivative.
    """
    context.set_context(mode=mode)
    net = MultipleInputsSingleOutputNet()
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    # For out = x * y * z: d out/dy = x * z and d out/dz = x * y (elementwise), so each
    # Jacobian is a (2, 2, 2, 2) tensor holding those products on its "diagonal".
    expect_jac_y = np.diag([0, 6, 15, -4]).reshape(2, 2, 2, 2).astype(np.float32)
    expect_jac_z = np.diag([-2, 6, -3, 8]).reshape(2, 2, 2, 2).astype(np.float32)

    jac = jacrev(net, grad_position=(1, 2))(x, y, z)
    assert isinstance(jac, tuple)
    assert len(jac) == 2
    jac_y, jac_z = jac
    assert np.allclose(jac_y.asnumpy(), expect_jac_y)
    assert np.allclose(jac_z.asnumpy(), expect_jac_z)
124
125
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_multiple_inputs_multiple_outputs_cell_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev on a net with three inputs and two outputs, differentiating
        w.r.t. grad_position=(1, 2) (y and z), in graph and PyNative modes.
    Expectation: jac[i][j] matches the analytic derivative of output i w.r.t. selected input j.
    """
    context.set_context(mode=mode)
    net = MultipleInputsMultipleOutputsNet()
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    # Diagonal entries of each expected (2, 2, 2, 2) Jacobian, indexed [output][input]:
    #   out0 = x**2 + y**2 + z**2 -> d/dy = 2 * y,  d/dz = 2 * z
    #   out1 = x * y * z          -> d/dy = x * z,  d/dz = x * y
    expected_diagonals = (
        ([-4, 6, -2, 4], [0, 6, 10, -2]),
        ([0, 6, 15, -4], [-2, 6, -3, 8]),
    )

    jac = jacrev(net, grad_position=(1, 2))(x, y, z)
    assert isinstance(jac, tuple)
    assert len(jac) == 2
    for out_idx, in_idx in ((0, 0), (0, 1), (1, 0), (1, 1)):
        diagonal = expected_diagonals[out_idx][in_idx]
        expect = np.diag(diagonal).reshape(2, 2, 2, 2).astype(np.float32)
        assert np.allclose(jac[out_idx][in_idx].asnumpy(), expect)
156
157
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_wrap_with_jit_function_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev with has_aux=True wrapped inside a @jit decorated function,
        in graph and PyNative modes.
    Expectation: The Jacobian w.r.t. x and the auxiliary output match the analytic values.
    """
    context.set_context(mode=mode)
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))

    jac, aux = jac_wrap_with_jit_function(x, y, z)

    # First output x**2 + y**2 + z**2 differentiated w.r.t. x gives 2 * x on the diagonal.
    expect_jac = np.diag([2, 4, 6, 8]).reshape(2, 2, 2, 2).astype(np.float32)
    # Second output x * y * z comes back unchanged as auxiliary data.
    expect_aux = np.array([[0, 18], [-15, -8]]).astype(np.float32)
    assert np.allclose(jac.asnumpy(), expect_jac)
    assert np.allclose(aux.asnumpy(), expect_aux)
178
179
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_with_grad_position_twice_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev on the same Cell twice, first with a scalar grad_position and
        then with a tuple grad_position, in graph and PyNative modes.
    Expectation: Both calls produce Jacobians that match the analytic derivatives, and the
        x-Jacobian agrees between the two calls.
    """
    context.set_context(mode=mode)
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 3], [5, 7]]).astype(np.float32))
    z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    # With z == 1: d(x*y*z)/dx = y * z = y and d(x*y*z)/dy = x * z = x (elementwise).
    expect_jac_0 = np.array([[[[1, 0], [0, 0]], [[0, 3], [0, 0]]],
                             [[[0, 0], [5, 0]], [[0, 0], [0, 7]]]]).astype(np.float32)
    expect_jac_1 = np.array([[[[1, 0], [0, 0]], [[0, 2], [0, 0]]],
                             [[[0, 0], [3, 0]], [[0, 0], [0, 4]]]]).astype(np.float32)
    net = MultipleInputsSingleOutputNet()
    jac1 = jacrev(net, grad_position=0)(x, y, z)
    jac2 = jacrev(net, grad_position=(0, 1))(x, y, z)

    # Scalar grad_position: a single Jacobian w.r.t. x.
    jac1_np = jac1.asnumpy()
    assert jac1_np.shape == expect_jac_0.shape
    assert np.allclose(jac1_np, expect_jac_0)

    # Tuple grad_position: one Jacobian per selected input. The x-Jacobian was previously
    # never checked; it must match the scalar-position result.
    assert isinstance(jac2, tuple)
    assert len(jac2) == 2
    jac2_x = jac2[0].asnumpy()
    jac2_y = jac2[1].asnumpy()
    assert jac2_x.shape == expect_jac_0.shape
    assert jac2_y.shape == expect_jac_1.shape
    assert np.allclose(jac2_x, expect_jac_0)
    assert np.allclose(jac2_y, expect_jac_1)
204
205
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_with_has_aux_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev with has_aux=True and grad_position=0 on a Cell with two
        outputs, in graph and PyNative modes.
    Expectation: The Jacobian of the first output and the auxiliary second output match
        the analytic values.
    """
    context.set_context(mode=mode)
    net = MultipleInputsMultipleOutputsNet()
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    z = Tensor(np.ones((2, 2)).astype(np.float32))

    jac, aux = jacrev(net, grad_position=0, has_aux=True)(x, y, z)

    # d(x**2 + y**2 + z**2)/dx = 2 * x on the diagonal of the (2, 2, 2, 2) Jacobian.
    expect_jac = np.diag([2, 4, 6, 8]).reshape(2, 2, 2, 2).astype(np.float32)
    # aux is x * y * z, which equals x ** 2 here because y == x and z == 1.
    expect_aux = np.array([[1, 4], [9, 16]]).astype(np.float32)
    assert np.allclose(jac.asnumpy(), expect_jac)
    assert np.allclose(aux.asnumpy(), expect_aux)
227
228
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_with_function_has_aux_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev with has_aux=True and grad_position=0 on a plain function
        that takes *args, in graph and PyNative modes.
    Expectation: The Jacobian of the first output and the auxiliary second output match
        the analytic values.
    """
    context.set_context(mode=mode)

    def fn(x, y, z):
        return x ** 2 + y ** 2 + z ** 2, x * y * z

    def fn2(*args):
        # Pick the first three variadic arguments and forward them to fn.
        x, y, z = args[0], args[1], args[2]
        return fn(x, y, z)

    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    z = Tensor(np.ones((2, 2)).astype(np.float32))

    jac, aux = jacrev(fn2, grad_position=0, has_aux=True)(x, y, z)

    # d(x**2 + y**2 + z**2)/dx = 2 * x on the diagonal of the (2, 2, 2, 2) Jacobian.
    expect_jac = np.diag([2, 4, 6, 8]).reshape(2, 2, 2, 2).astype(np.float32)
    # aux is x * y * z, which equals x ** 2 here because y == x and z == 1.
    expect_aux = np.array([[1, 4], [9, 16]]).astype(np.float32)
    assert np.allclose(jac.asnumpy(), expect_jac)
    assert np.allclose(aux.asnumpy(), expect_aux)
258