• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""test jvp in graph mode"""
16
17import numpy as np
18import pytest
19import mindspore.nn as nn
20import mindspore.context as context
21from mindspore import Tensor
22from mindspore.nn.grad import Jvp
23
# Run every test in this module in MindSpore graph mode. This call sits at
# module level, so the global context setting takes effect as soon as the
# file is imported, before any test below executes.
context.set_context(mode=context.GRAPH_MODE)
25
26
class SingleInputSingleOutputNet(nn.Cell):
    """Cell with one input and one output: returns ``x**3``."""

    def construct(self, x):
        cubed = x**3
        return cubed
30
31
class SingleInputMultipleOutputNet(nn.Cell):
    """Cell with one input and two outputs: returns ``(x**3, 2*x)``."""

    def construct(self, x):
        cubed = x**3
        doubled = 2*x
        return cubed, doubled
35
36
class MultipleInputSingleOutputNet(nn.Cell):
    """Cell with two inputs and one output: returns ``2*x + 3*y``."""

    def construct(self, x, y):
        combined = 2*x + 3*y
        return combined
40
41
class MultipleInputMultipleOutputNet(nn.Cell):
    """Cell with two inputs and two outputs: returns ``(2*x, y**3)``."""

    def construct(self, x, y):
        doubled = 2*x
        cubed = y**3
        return doubled, cubed
45
46
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_default_v_graph():
    """Jvp of f(x) = x**3 with an all-ones tangent.

    The tangent output of an elementwise cube is 3 * x**2 * v, so with
    v = 1 it equals 3 * x**2.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    expect_primal = np.array([[1, 8], [27, 64]]).astype(np.float32)
    expect_grad = np.array([[3, 12], [27, 48]]).astype(np.float32)
    primal, grad = Jvp(net)(x, v)
    primal_np = primal.asnumpy()
    grad_np = grad.asnumpy()
    # np.allclose broadcasts its arguments, so a scalar or wrongly shaped
    # result could compare equal; pin the shapes explicitly.
    assert primal_np.shape == expect_primal.shape
    assert np.allclose(primal_np, expect_primal)
    assert grad_np.shape == expect_grad.shape
    assert np.allclose(grad_np, expect_grad)
59
60
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_custom_v_graph():
    """Jvp of f(x) = x**3 with a non-uniform tangent.

    The tangent output is 3 * x**2 * v elementwise; a non-constant v checks
    that the tangent is applied per element rather than as a scalar factor.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    expect_primal = np.array([[1, 8], [27, 64]]).astype(np.float32)
    expect_grad = np.array([[3, 24], [81, 192]]).astype(np.float32)
    primal, grad = Jvp(net)(x, v)
    primal_np = primal.asnumpy()
    grad_np = grad.asnumpy()
    # np.allclose broadcasts, so check shapes explicitly as well.
    assert primal_np.shape == expect_primal.shape
    assert np.allclose(primal_np, expect_primal)
    assert grad_np.shape == expect_grad.shape
    assert np.allclose(grad_np, expect_grad)
73
74
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_default_v_graph():
    """Jvp of f(x) = (x**3, 2*x) with an all-ones tangent.

    Both the primal and the tangent results must be 2-tuples, one entry per
    network output: tangents 3 * x**2 * v and 2 * v respectively.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputMultipleOutputNet()
    expect_primal_0 = np.array([[1, 8], [27, 64]]).astype(np.float32)
    expect_primal_1 = np.array([[2, 4], [6, 8]]).astype(np.float32)
    expect_grad_0 = np.array([[3, 12], [27, 48]]).astype(np.float32)
    expect_grad_1 = np.array([[2, 2], [2, 2]]).astype(np.float32)
    primal, grad = Jvp(net)(x, v)
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert isinstance(grad, tuple)
    assert len(grad) == 2
    # np.allclose broadcasts (e.g. a scalar 2 would match an all-2 array),
    # so every output's shape is asserted before its values.
    assert primal[0].asnumpy().shape == expect_primal_0.shape
    assert np.allclose(primal[0].asnumpy(), expect_primal_0)
    assert primal[1].asnumpy().shape == expect_primal_1.shape
    assert np.allclose(primal[1].asnumpy(), expect_primal_1)
    assert grad[0].asnumpy().shape == expect_grad_0.shape
    assert np.allclose(grad[0].asnumpy(), expect_grad_0)
    assert grad[1].asnumpy().shape == expect_grad_1.shape
    assert np.allclose(grad[1].asnumpy(), expect_grad_1)
95
96
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_custom_v_graph():
    """Jvp of f(x) = (x**3, 2*x) with a non-uniform tangent.

    Expected tangents are 3 * x**2 * v and 2 * v, elementwise.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputMultipleOutputNet()
    expect_primal_0 = np.array([[1, 8], [27, 64]]).astype(np.float32)
    expect_primal_1 = np.array([[2, 4], [6, 8]]).astype(np.float32)
    expect_grad_0 = np.array([[3, 24], [81, 192]]).astype(np.float32)
    expect_grad_1 = np.array([[2, 4], [6, 8]]).astype(np.float32)
    primal, grad = Jvp(net)(x, v)
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert isinstance(grad, tuple)
    assert len(grad) == 2
    # np.allclose broadcasts, so each output's shape is asserted too.
    assert primal[0].asnumpy().shape == expect_primal_0.shape
    assert np.allclose(primal[0].asnumpy(), expect_primal_0)
    assert primal[1].asnumpy().shape == expect_primal_1.shape
    assert np.allclose(primal[1].asnumpy(), expect_primal_1)
    assert grad[0].asnumpy().shape == expect_grad_0.shape
    assert np.allclose(grad[0].asnumpy(), expect_grad_0)
    assert grad[1].asnumpy().shape == expect_grad_1.shape
    assert np.allclose(grad[1].asnumpy(), expect_grad_1)
117
118
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_default_v_graph():
    """Jvp of f(x, y) = 2*x + 3*y with all-ones tangents for both inputs.

    x and y deliberately hold different values so that passing the inputs
    in swapped order would change the primal result and fail the test.
    The tangent output is 2 * v_x + 3 * v_y, i.e. 5 everywhere here.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[2, 3], [4, 5]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    expect_primal = np.array([[8, 13], [18, 23]]).astype(np.float32)
    expect_grad = np.array([[5, 5], [5, 5]]).astype(np.float32)
    primal, grad = Jvp(net)(x, y, (v, v))
    primal_np = primal.asnumpy()
    grad_np = grad.asnumpy()
    # The expected tangent is constant, so without a shape check a scalar 5
    # would pass np.allclose via broadcasting.
    assert primal_np.shape == expect_primal.shape
    assert np.allclose(primal_np, expect_primal)
    assert grad_np.shape == expect_grad.shape
    assert np.allclose(grad_np, expect_grad)
132
133
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_custom_v_graph():
    """Jvp of f(x, y) = 2*x + 3*y with distinct tangents per input.

    x and y hold different values so an input-order swap changes the
    primal. The tangent output is 2 * v1 + 3 * v2; using different v1 and
    v2 also catches the tangents being paired with the wrong input.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[2, 3], [4, 5]]).astype(np.float32))
    v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    expect_primal = np.array([[8, 13], [18, 23]]).astype(np.float32)
    expect_grad = np.array([[5, 8], [11, 14]]).astype(np.float32)
    primal, grad = Jvp(net)(x, y, (v1, v2))
    primal_np = primal.asnumpy()
    grad_np = grad.asnumpy()
    # np.allclose broadcasts, so check shapes explicitly as well.
    assert primal_np.shape == expect_primal.shape
    assert np.allclose(primal_np, expect_primal)
    assert grad_np.shape == expect_grad.shape
    assert np.allclose(grad_np, expect_grad)
148
149
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_default_v_graph():
    """Jvp of f(x, y) = (2*x, y**3) with all-ones tangents for both inputs.

    x and y hold different values so that passing the inputs in swapped
    order would change both primal outputs. Expected tangents are 2 * v_x
    and 3 * y**2 * v_y.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[2, 3], [4, 5]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputMultipleOutputNet()
    expect_primal_0 = np.array([[2, 4], [6, 8]]).astype(np.float32)
    expect_primal_1 = np.array([[8, 27], [64, 125]]).astype(np.float32)
    expect_grad_0 = np.array([[2, 2], [2, 2]]).astype(np.float32)
    expect_grad_1 = np.array([[12, 27], [48, 75]]).astype(np.float32)
    primal, grad = Jvp(net)(x, y, (v, v))
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert isinstance(grad, tuple)
    assert len(grad) == 2
    # np.allclose broadcasts (a scalar 2 would match expect_grad_0), so
    # each output's shape is asserted before its values.
    assert primal[0].asnumpy().shape == expect_primal_0.shape
    assert np.allclose(primal[0].asnumpy(), expect_primal_0)
    assert primal[1].asnumpy().shape == expect_primal_1.shape
    assert np.allclose(primal[1].asnumpy(), expect_primal_1)
    assert grad[0].asnumpy().shape == expect_grad_0.shape
    assert np.allclose(grad[0].asnumpy(), expect_grad_0)
    assert grad[1].asnumpy().shape == expect_grad_1.shape
    assert np.allclose(grad[1].asnumpy(), expect_grad_1)
171
172
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
    """Jvp of f(x, y) = (2*x, y**3) with distinct tangents per input.

    x and y hold different values so an input-order swap changes the
    primal outputs. Expected tangents are 2 * v1 and 3 * y**2 * v2; using
    different v1 and v2 also catches tangents paired with the wrong input.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[2, 3], [4, 5]]).astype(np.float32))
    v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = MultipleInputMultipleOutputNet()
    expect_primal_0 = np.array([[2, 4], [6, 8]]).astype(np.float32)
    expect_primal_1 = np.array([[8, 27], [64, 125]]).astype(np.float32)
    expect_grad_0 = np.array([[2, 2], [2, 2]]).astype(np.float32)
    expect_grad_1 = np.array([[12, 54], [144, 300]]).astype(np.float32)
    primal, grad = Jvp(net)(x, y, (v1, v2))
    assert isinstance(primal, tuple)
    assert len(primal) == 2
    assert isinstance(grad, tuple)
    assert len(grad) == 2
    # np.allclose broadcasts, so each output's shape is asserted too.
    assert primal[0].asnumpy().shape == expect_primal_0.shape
    assert np.allclose(primal[0].asnumpy(), expect_primal_0)
    assert primal[1].asnumpy().shape == expect_primal_1.shape
    assert np.allclose(primal[1].asnumpy(), expect_primal_1)
    assert grad[0].asnumpy().shape == expect_grad_0.shape
    assert np.allclose(grad[0].asnumpy(), expect_grad_0)
    assert grad[1].asnumpy().shape == expect_grad_1.shape
    assert np.allclose(grad[1].asnumpy(), expect_grad_1)
195