• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for mindspore.nn.grad.Jvp in PyNative mode."""
16
17import numpy as np
18import pytest
19import mindspore.nn as nn
20import mindspore.context as context
21from mindspore import Tensor
22from mindspore.nn.grad import Jvp
23
# Every Jvp test in this module runs under MindSpore's PyNative mode.
context.set_context(
    mode=context.PYNATIVE_MODE,
)
25
class SingleInputSingleOutputNet(nn.Cell):
    """Element-wise cube of a single input: f(x) = x ** 3."""

    def construct(self, x):
        cubed = x ** 3
        return cubed
29
30
class SingleInputMultipleOutputNet(nn.Cell):
    """One input, two outputs: f(x) = (x ** 3, 2 * x)."""

    def construct(self, x):
        cubed = x ** 3
        doubled = 2 * x
        return cubed, doubled
34
35
class MultipleInputSingleOutputNet(nn.Cell):
    """Two inputs, one output: the linear combination f(x, y) = 2 * x + 3 * y."""

    def construct(self, x, y):
        weighted_x = 2 * x
        weighted_y = 3 * y
        return weighted_x + weighted_y
39
40
class MultipleInputMultipleOutputNet(nn.Cell):
    """Two inputs, two outputs: f(x, y) = (2 * x, y ** 3)."""

    def construct(self, x, y):
        doubled_x = 2 * x
        cubed_y = y ** 3
        return doubled_x, cubed_y
44
45
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_default_v_pynative():
    """Jvp of f(x) = x ** 3 with an all-ones tangent.

    The cube is element-wise, so its Jacobian is diagonal with entries
    3 * x ** 2; multiplying by v = 1 leaves exactly 3 * x ** 2.
    """
    inputs = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    tangent = Tensor(np.ones((2, 2), dtype=np.float32))
    net = SingleInputSingleOutputNet()
    want_primal = np.array([[1, 8], [27, 64]], dtype=np.float32)
    want_tangent = np.array([[3, 12], [27, 48]], dtype=np.float32)

    got_primal, got_tangent = Jvp(net)(inputs, tangent)

    assert np.allclose(got_primal.asnumpy(), want_primal)
    assert np.allclose(got_tangent.asnumpy(), want_tangent)
58
59
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_custom_v_pynative():
    """Jvp of f(x) = x ** 3 with a non-uniform tangent.

    Here the tangent has the same values as x, so the expected
    Jacobian-vector product is 3 * x ** 2 * x = 3 * x ** 3.
    """
    inputs = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    tangent = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    want_primal = np.array([[1, 8], [27, 64]], dtype=np.float32)
    want_tangent = np.array([[3, 24], [81, 192]], dtype=np.float32)

    got_primal, got_tangent = Jvp(net)(inputs, tangent)

    assert np.allclose(got_primal.asnumpy(), want_primal)
    assert np.allclose(got_tangent.asnumpy(), want_tangent)
72
73
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_default_v_pynative():
    """Jvp of x -> (x ** 3, 2 * x) with an all-ones tangent.

    Jvp must return one primal and one tangent per output, as tuples:
    3 * x ** 2 for the cube and the constant 2 for the doubling.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.ones((2, 2), dtype=np.float32))
    net = SingleInputMultipleOutputNet()
    expected_primals = (
        np.array([[1, 8], [27, 64]], dtype=np.float32),
        np.array([[2, 4], [6, 8]], dtype=np.float32),
    )
    expected_tangents = (
        np.array([[3, 12], [27, 48]], dtype=np.float32),
        np.array([[2, 2], [2, 2]], dtype=np.float32),
    )

    primal, grad = Jvp(net)(x, v)

    # Check the primal tuple first, then the tangent tuple.
    for outputs, expected in ((primal, expected_primals), (grad, expected_tangents)):
        assert isinstance(outputs, tuple)
        assert len(outputs) == 2
        for got, want in zip(outputs, expected):
            assert np.allclose(got.asnumpy(), want)
94
95
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_custom_v_pynative():
    """Jvp of x -> (x ** 3, 2 * x) with a non-uniform tangent.

    The tangent has the same values as x, so the cube's tangent output is
    3 * x ** 2 * v = 3 * x ** 3 and the doubling's is 2 * v.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputMultipleOutputNet()
    expected_primals = (
        np.array([[1, 8], [27, 64]], dtype=np.float32),
        np.array([[2, 4], [6, 8]], dtype=np.float32),
    )
    expected_tangents = (
        np.array([[3, 24], [81, 192]], dtype=np.float32),
        np.array([[2, 4], [6, 8]], dtype=np.float32),
    )

    primal, grad = Jvp(net)(x, v)

    # Check the primal tuple first, then the tangent tuple.
    for outputs, expected in ((primal, expected_primals), (grad, expected_tangents)):
        assert isinstance(outputs, tuple)
        assert len(outputs) == 2
        for got, want in zip(outputs, expected):
            assert np.allclose(got.asnumpy(), want)
116
117
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_default_v_pynative():
    """Jvp of (x, y) -> (2 * x, y ** 3) with an all-ones tangent per input.

    ``y`` deliberately holds different values from ``x``. With identical
    inputs, a Jvp that handed the inputs to the network in the wrong order
    would still reproduce every expected value, so the test could not catch
    it.

    Expected tangent outputs are 2 * v_x = 2 and 3 * y ** 2 * v_y = 3 * y ** 2.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[2, 3], [4, 5]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputMultipleOutputNet()
    expected_primals = (
        np.array([[2, 4], [6, 8]], dtype=np.float32),       # 2 * x
        np.array([[8, 27], [64, 125]], dtype=np.float32),   # y ** 3
    )
    expected_tangents = (
        np.array([[2, 2], [2, 2]], dtype=np.float32),       # 2 * v
        np.array([[12, 27], [48, 75]], dtype=np.float32),   # 3 * y ** 2 * v
    )

    primal, grad = Jvp(net)(x, y, (v, v))

    # Check the primal tuple first, then the tangent tuple.
    for outputs, expected in ((primal, expected_primals), (grad, expected_tangents)):
        assert isinstance(outputs, tuple)
        assert len(outputs) == 2
        for got, want in zip(outputs, expected):
            assert np.allclose(got.asnumpy(), want)
139
140
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_custom_v_pynative():
    """Jvp of (x, y) -> (2 * x, y ** 3) with a different tangent per input.

    ``y`` deliberately holds different values from ``x``, so that a Jvp which
    mixed up the order of the inputs would change the primal outputs. The
    tangents also differ (v1 for x, v2 for y), so a mix-up in tangent order
    changes the tangent outputs.

    Expected tangent outputs are 2 * v1 and 3 * y ** 2 * v2.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[2, 3], [4, 5]]).astype(np.float32))
    v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = MultipleInputMultipleOutputNet()
    expected_primals = (
        np.array([[2, 4], [6, 8]], dtype=np.float32),        # 2 * x
        np.array([[8, 27], [64, 125]], dtype=np.float32),    # y ** 3
    )
    expected_tangents = (
        np.array([[2, 2], [2, 2]], dtype=np.float32),        # 2 * v1
        np.array([[12, 54], [144, 300]], dtype=np.float32),  # 3 * y ** 2 * v2
    )

    primal, grad = Jvp(net)(x, y, (v1, v2))

    # Check the primal tuple first, then the tangent tuple.
    for outputs, expected in ((primal, expected_primals), (grad, expected_tangents)):
        assert isinstance(outputs, tuple)
        assert len(outputs) == 2
        for got, want in zip(outputs, expected):
            assert np.allclose(got.asnumpy(), want)
163
164
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_default_v_pynative():
    """Jvp of (x, y) -> 2 * x + 3 * y with an all-ones tangent per input.

    ``y`` deliberately holds different values from ``x``. With identical
    inputs, swapping them would leave 2 * x + 3 * y unchanged and hide an
    input-order bug.

    The function is linear, so the tangent output is 2 * v + 3 * v = 5
    everywhere.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[2, 3], [4, 5]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    expect_primal = np.array([[8, 13], [18, 23]], dtype=np.float32)  # 2 * x + 3 * y
    expect_grad = np.array([[5, 5], [5, 5]], dtype=np.float32)       # 2 * v + 3 * v
    primal, grad = Jvp(net)(x, y, (v, v))
    assert np.allclose(primal.asnumpy(), expect_primal)
    assert np.allclose(grad.asnumpy(), expect_grad)
178
179
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_custom_v_pynative():
    """Jvp of (x, y) -> 2 * x + 3 * y with a different tangent per input.

    ``y`` deliberately holds different values from ``x``, so that a swapped
    input order changes the primal. The tangents also differ, so that a
    swapped tangent order changes the tangent output.

    Expected tangent output is 2 * v1 + 3 * v2.
    """
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[2, 3], [4, 5]]).astype(np.float32))
    v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    expect_primal = np.array([[8, 13], [18, 23]], dtype=np.float32)  # 2 * x + 3 * y
    expect_grad = np.array([[5, 8], [11, 14]], dtype=np.float32)     # 2 * v1 + 3 * v2
    primal, grad = Jvp(net)(x, y, (v1, v2))
    assert np.allclose(primal.asnumpy(), expect_primal)
    assert np.allclose(grad.asnumpy(), expect_grad)
194