# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for the forward-mode differentiation cell ``Jvp`` in graph and PyNative modes."""

import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn.grad import Jvp
23
24
class SingleInputSingleOutputNet(nn.Cell):
    """Cell mapping one tensor to its elementwise cube, f(x) = x**3."""

    def construct(self, x):
        # Single output: the cube of the input.
        cube = x**3
        return cube
28
29
class SingleInputMultipleOutputNet(nn.Cell):
    """Cell mapping one tensor to the pair f(x) = (x**3, 2*x)."""

    def construct(self, x):
        # Two outputs computed from the same input, in this order.
        cube = x**3
        doubled = 2*x
        return cube, doubled
33
34
class MultipleInputSingleOutputNet(nn.Cell):
    """Cell combining two tensors into one output, f(x, y) = 2*x + 3*y."""

    def construct(self, x, y):
        # Linear combination of both inputs.
        scaled_x = 2*x
        scaled_y = 3*y
        return scaled_x + scaled_y
38
39
class MultipleInputMultipleOutputNet(nn.Cell):
    """Cell mapping two tensors to the pair f(x, y) = (2*x, y**3)."""

    def construct(self, x, y):
        # Each output depends on exactly one input.
        doubled = 2*x
        cube = y**3
        return doubled, cube
43
44
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_single_input_single_output_default_v_graph(mode):
    """
    Feature: Class Jvp.
    Description: Forward-mode derivative of f(x) = x**3 with an all-ones tangent vector.
    Expectation: Primal equals x**3 and the tangent output equals 3 * x**2 * v.
    """
    context.set_context(mode=mode)
    x_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
    v_np = np.ones((2, 2), dtype=np.float32)

    jvp = Jvp(SingleInputSingleOutputNet())
    primal, grad = jvp(Tensor(x_np), Tensor(v_np))

    # f(x) = x**3  =>  J(x) v = 3 * x**2 * v (elementwise).
    assert np.allclose(primal.asnumpy(), x_np ** 3)
    assert np.allclose(grad.asnumpy(), 3 * x_np ** 2 * v_np)
64
65
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_single_input_single_output_custom_v_graph(mode):
    """
    Feature: Class Jvp.
    Description: Forward-mode derivative of f(x) = x**3 with a non-uniform tangent vector.
    Expectation: Primal equals x**3 and the tangent output equals 3 * x**2 * v.
    """
    context.set_context(mode=mode)
    x_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
    v_np = np.array([[1, 2], [3, 4]], dtype=np.float32)

    jvp = Jvp(SingleInputSingleOutputNet())
    primal, grad = jvp(Tensor(x_np), Tensor(v_np))

    # f(x) = x**3  =>  J(x) v = 3 * x**2 * v (elementwise).
    assert np.allclose(primal.asnumpy(), x_np ** 3)
    assert np.allclose(grad.asnumpy(), 3 * x_np ** 2 * v_np)
85
86
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_single_input_multiple_outputs_default_v_graph(mode):
    """
    Feature: Class Jvp.
    Description: Forward-mode derivative of f(x) = (x**3, 2 * x) with an all-ones tangent vector.
    Expectation: Primal and tangent outputs are 2-tuples matching the analytic values.
    """
    context.set_context(mode=mode)
    x_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
    v_np = np.ones((2, 2), dtype=np.float32)

    jvp = Jvp(SingleInputMultipleOutputNet())
    primal, grad = jvp(Tensor(x_np), Tensor(v_np))

    # d(x**3) v = 3 * x**2 * v and d(2 * x) v = 2 * v, elementwise.
    checks = (
        (primal, (x_np ** 3, 2 * x_np)),
        (grad, (3 * x_np ** 2 * v_np, 2 * v_np)),
    )
    for outputs, references in checks:
        assert isinstance(outputs, tuple)
        assert len(outputs) == len(references)
        for output, reference in zip(outputs, references):
            assert np.allclose(output.asnumpy(), reference)
114
115
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_single_input_multiple_outputs_custom_v_graph(mode):
    """
    Feature: Class Jvp.
    Description: Forward-mode derivative of f(x) = (x**3, 2 * x) with a non-uniform tangent vector.
    Expectation: Primal and tangent outputs are 2-tuples matching the analytic values.
    """
    context.set_context(mode=mode)
    x_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
    v_np = np.array([[1, 2], [3, 4]], dtype=np.float32)

    jvp = Jvp(SingleInputMultipleOutputNet())
    primal, grad = jvp(Tensor(x_np), Tensor(v_np))

    # d(x**3) v = 3 * x**2 * v and d(2 * x) v = 2 * v, elementwise.
    checks = (
        (primal, (x_np ** 3, 2 * x_np)),
        (grad, (3 * x_np ** 2 * v_np, 2 * v_np)),
    )
    for outputs, references in checks:
        assert isinstance(outputs, tuple)
        assert len(outputs) == len(references)
        for output, reference in zip(outputs, references):
            assert np.allclose(output.asnumpy(), reference)
143
144
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_multiple_inputs_single_output_default_v_graph(mode):
    """
    Feature: Class Jvp.
    Description: Forward-mode derivative of f(x, y) = 2 * x + 3 * y with all-ones tangents.
    Expectation: Primal equals 2 * x + 3 * y and the tangent output equals 2 * vx + 3 * vy.
    """
    context.set_context(mode=mode)
    x_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
    y_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
    v_np = np.ones((2, 2), dtype=np.float32)
    # The same tangent tensor is used for both inputs.
    v = Tensor(v_np)

    jvp = Jvp(MultipleInputSingleOutputNet())
    primal, grad = jvp(Tensor(x_np), Tensor(y_np), (v, v))

    # f is linear: J (vx, vy) = 2 * vx + 3 * vy.
    assert np.allclose(primal.asnumpy(), 2 * x_np + 3 * y_np)
    assert np.allclose(grad.asnumpy(), 2 * v_np + 3 * v_np)
165
166
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_multiple_inputs_single_output_custom_v_graph(mode):
    """
    Feature: Class Jvp.
    Description: Forward-mode derivative of f(x, y) = 2 * x + 3 * y with distinct tangents per input.
    Expectation: Primal equals 2 * x + 3 * y and the tangent output equals 2 * v1 + 3 * v2.
    """
    context.set_context(mode=mode)
    x_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
    y_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
    v1_np = np.ones((2, 2), dtype=np.float32)
    v2_np = np.array([[1, 2], [3, 4]], dtype=np.float32)

    jvp = Jvp(MultipleInputSingleOutputNet())
    primal, grad = jvp(Tensor(x_np), Tensor(y_np), (Tensor(v1_np), Tensor(v2_np)))

    # f is linear: J (v1, v2) = 2 * v1 + 3 * v2.
    assert np.allclose(primal.asnumpy(), 2 * x_np + 3 * y_np)
    assert np.allclose(grad.asnumpy(), 2 * v1_np + 3 * v2_np)
188
189
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_multiple_inputs_multiple_outputs_default_v_graph(mode):
    """
    Feature: Class Jvp.
    Description: Forward-mode derivative of f(x, y) = (2 * x, y**3) with all-ones tangents.
    Expectation: Primal and tangent outputs are 2-tuples matching the analytic values.
    """
    context.set_context(mode=mode)
    x_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
    y_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
    v_np = np.ones((2, 2), dtype=np.float32)
    # The same tangent tensor is used for both inputs.
    v = Tensor(v_np)

    jvp = Jvp(MultipleInputMultipleOutputNet())
    primal, grad = jvp(Tensor(x_np), Tensor(y_np), (v, v))

    # J (vx, vy) = (2 * vx, 3 * y**2 * vy), elementwise.
    checks = (
        (primal, (2 * x_np, y_np ** 3)),
        (grad, (2 * v_np, 3 * y_np ** 2 * v_np)),
    )
    for outputs, references in checks:
        assert isinstance(outputs, tuple)
        assert len(outputs) == len(references)
        for output, reference in zip(outputs, references):
            assert np.allclose(output.asnumpy(), reference)
218
219
@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph(mode):
    """
    Feature: Class Jvp.
    Description: Forward-mode derivative of f(x, y) = (2 * x, y**3) with distinct tangents per input.
    Expectation: Primal and tangent outputs are 2-tuples matching the analytic values.
    """
    context.set_context(mode=mode)
    x_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
    y_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
    v1_np = np.ones((2, 2), dtype=np.float32)
    v2_np = np.array([[1, 2], [3, 4]], dtype=np.float32)

    jvp = Jvp(MultipleInputMultipleOutputNet())
    primal, grad = jvp(Tensor(x_np), Tensor(y_np), (Tensor(v1_np), Tensor(v2_np)))

    # J (v1, v2) = (2 * v1, 3 * y**2 * v2), elementwise.
    checks = (
        (primal, (2 * x_np, y_np ** 3)),
        (grad, (2 * v1_np, 3 * y_np ** 2 * v2_np)),
    )
    for outputs, references in checks:
        assert isinstance(outputs, tuple)
        assert len(outputs) == len(references)
        for output, reference in zip(outputs, references):
            assert np.allclose(output.asnumpy(), reference)
249