• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""test jvp in pynative mode """
16
17import numpy as np
18import pytest
19import mindspore.nn as nn
20import mindspore.context as context
21from mindspore import Tensor
22from mindspore.nn.grad import Jvp
23
# Every test in this module exercises Jvp in PyNative mode, so switch the
# global execution mode once at import time.
context.set_context(mode=context.PYNATIVE_MODE)
25
class SingleInputSingleOutputNet(nn.Cell):
    """Cell taking one tensor ``x`` and returning ``x**3``."""

    def construct(self, x):
        cubed = x**3
        return cubed
29
30
class SingleInputMultipleOutputNet(nn.Cell):
    """Cell taking one tensor ``x`` and returning the pair ``(x**3, 2*x)``."""

    def construct(self, x):
        cubed = x**3
        doubled = 2*x
        return cubed, doubled
34
35
class MultipleInputSingleOutputNet(nn.Cell):
    """Cell taking tensors ``x`` and ``y`` and returning ``2*x + 3*y``."""

    def construct(self, x, y):
        scaled_x = 2*x
        scaled_y = 3*y
        return scaled_x + scaled_y
39
40
class MultipleInputMultipleOutputNet(nn.Cell):
    """Cell taking tensors ``x`` and ``y`` and returning ``(2*x, y**3)``."""

    def construct(self, x, y):
        doubled = 2*x
        cubed = y**3
        return doubled, cubed
44
45
def test_jvp_single_input_single_output_default_v_pynative():
    """Jvp of ``x**3`` with an all-ones tangent yields ``(x**3, 3*x**2)``."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    expect_primal = np.array([[1, 8], [27, 64]]).astype(np.float32)
    # d(x**3)/dx = 3*x**2, multiplied element-wise by v (all ones here).
    expect_grad = np.array([[3, 12], [27, 48]]).astype(np.float32)
    primal, grad = Jvp(net)(x, v)
    assert np.allclose(primal.asnumpy(), expect_primal)
    assert np.allclose(grad.asnumpy(), expect_grad)
51
52
def test_jvp_single_input_single_output_custom_v_pynative():
    """Jvp of ``x**3`` with a non-uniform tangent yields ``(x**3, 3*x**2 * v)``."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    expect_primal = np.array([[1, 8], [27, 64]]).astype(np.float32)
    # 3*x**2 = [[3, 12], [27, 48]], scaled element-wise by v.
    expect_grad = np.array([[3, 24], [81, 192]]).astype(np.float32)
    primal, grad = Jvp(net)(x, v)
    assert np.allclose(primal.asnumpy(), expect_primal)
    assert np.allclose(grad.asnumpy(), expect_grad)
58
59
def test_jvp_single_input_multiple_outputs_default_v_pynative():
    """Jvp of ``(x**3, 2*x)`` with an all-ones tangent gives one tangent per output."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputMultipleOutputNet()
    expect_primal = (np.array([[1, 8], [27, 64]]).astype(np.float32),
                     np.array([[2, 4], [6, 8]]).astype(np.float32))
    # Tangents: 3*x**2 * v for x**3, and 2 * v for 2*x.
    expect_grad = (np.array([[3, 12], [27, 48]]).astype(np.float32),
                   np.array([[2, 2], [2, 2]]).astype(np.float32))
    primal, grad = Jvp(net)(x, v)
    assert isinstance(primal, tuple) and len(primal) == 2
    assert isinstance(grad, tuple) and len(grad) == 2
    for actual, expected in zip(primal, expect_primal):
        assert np.allclose(actual.asnumpy(), expected)
    for actual, expected in zip(grad, expect_grad):
        assert np.allclose(actual.asnumpy(), expected)
65
66
def test_jvp_single_input_multiple_outputs_custom_v_pynative():
    """Jvp of ``(x**3, 2*x)`` with a non-uniform tangent gives one tangent per output."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputMultipleOutputNet()
    expect_primal = (np.array([[1, 8], [27, 64]]).astype(np.float32),
                     np.array([[2, 4], [6, 8]]).astype(np.float32))
    # Tangents: 3*x**2 * v for x**3, and 2 * v for 2*x.
    expect_grad = (np.array([[3, 24], [81, 192]]).astype(np.float32),
                   np.array([[2, 4], [6, 8]]).astype(np.float32))
    primal, grad = Jvp(net)(x, v)
    assert isinstance(primal, tuple) and len(primal) == 2
    assert isinstance(grad, tuple) and len(grad) == 2
    for actual, expected in zip(primal, expect_primal):
        assert np.allclose(actual.asnumpy(), expected)
    for actual, expected in zip(grad, expect_grad):
        assert np.allclose(actual.asnumpy(), expected)
72
73
def test_jvp_multiple_inputs_multiple_outputs_default_v_pynative():
    """Jvp of ``(2*x, y**3)`` with all-ones tangents for both inputs."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputMultipleOutputNet()
    expect_primal = (np.array([[2, 4], [6, 8]]).astype(np.float32),
                     np.array([[1, 8], [27, 64]]).astype(np.float32))
    # Tangents: 2 * v_x for 2*x, and 3*y**2 * v_y for y**3.
    expect_grad = (np.array([[2, 2], [2, 2]]).astype(np.float32),
                   np.array([[3, 12], [27, 48]]).astype(np.float32))
    primal, grad = Jvp(net)(x, y, (v, v))
    assert isinstance(primal, tuple) and len(primal) == 2
    assert isinstance(grad, tuple) and len(grad) == 2
    for actual, expected in zip(primal, expect_primal):
        assert np.allclose(actual.asnumpy(), expected)
    for actual, expected in zip(grad, expect_grad):
        assert np.allclose(actual.asnumpy(), expected)
80
81
def test_jvp_multiple_inputs_multiple_outputs_custom_v_pynative():
    """Jvp of ``(2*x, y**3)`` with distinct tangents ``v1`` (for x) and ``v2`` (for y)."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = MultipleInputMultipleOutputNet()
    expect_primal = (np.array([[2, 4], [6, 8]]).astype(np.float32),
                     np.array([[1, 8], [27, 64]]).astype(np.float32))
    # Tangents: 2 * v1 for 2*x, and 3*y**2 * v2 for y**3.
    expect_grad = (np.array([[2, 2], [2, 2]]).astype(np.float32),
                   np.array([[3, 24], [81, 192]]).astype(np.float32))
    primal, grad = Jvp(net)(x, y, (v1, v2))
    assert isinstance(primal, tuple) and len(primal) == 2
    assert isinstance(grad, tuple) and len(grad) == 2
    for actual, expected in zip(primal, expect_primal):
        assert np.allclose(actual.asnumpy(), expected)
    for actual, expected in zip(grad, expect_grad):
        assert np.allclose(actual.asnumpy(), expected)
89
90
def test_jvp_multiple_inputs_single_output_default_v_pynative():
    """Jvp of ``2*x + 3*y`` with all-ones tangents for both inputs."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    expect_primal = np.array([[5, 10], [15, 20]]).astype(np.float32)
    # Tangent: 2 * v_x + 3 * v_y = 5 everywhere when both tangents are ones.
    expect_grad = np.array([[5, 5], [5, 5]]).astype(np.float32)
    primal, grad = Jvp(net)(x, y, (v, v))
    assert np.allclose(primal.asnumpy(), expect_primal)
    assert np.allclose(grad.asnumpy(), expect_grad)
97
98
def test_jvp_multiple_inputs_single_output_custom_v_pynative():
    """Jvp of ``2*x + 3*y`` with distinct tangents ``v1`` (for x) and ``v2`` (for y)."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    expect_primal = np.array([[5, 10], [15, 20]]).astype(np.float32)
    # Tangent: 2 * v1 + 3 * v2.
    expect_grad = np.array([[5, 8], [11, 14]]).astype(np.float32)
    primal, grad = Jvp(net)(x, y, (v1, v2))
    assert np.allclose(primal.asnumpy(), expect_primal)
    assert np.allclose(grad.asnumpy(), expect_grad)
106
107
def test_jvp_wrong_input_v_pynative():
    """Passing a pair of tangents to a one-input net must raise TypeError."""
    net = SingleInputSingleOutputNet()
    primal_in = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    tangent = Tensor(np.ones((2, 2), dtype=np.float32))
    with pytest.raises(TypeError):
        Jvp(net)(primal_in, (tangent, tangent))
114
115
def test_jvp_wrong_input_v_2_pynative():
    """Wrapping the single tangent in a 1-tuple must raise TypeError."""
    net = SingleInputSingleOutputNet()
    primal_in = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    tangent = Tensor(np.ones((2, 2), dtype=np.float32))
    wrapped_tangent = (tangent,)
    with pytest.raises(TypeError):
        Jvp(net)(primal_in, wrapped_tangent)
122
123
def test_jvp_wrong_input_pynative():
    """Supplying two primal inputs to a one-argument net must raise TypeError."""
    net = SingleInputSingleOutputNet()
    primal_in = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    tangent = Tensor(np.ones((2, 2), dtype=np.float32))
    with pytest.raises(TypeError):
        Jvp(net)(primal_in, primal_in, tangent)
130
131
def test_jvp_wrong_input_2_pynative():
    """Packing both primal inputs into a single tuple must raise TypeError."""
    net = MultipleInputSingleOutputNet()
    first = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    second = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    tangent = Tensor(np.ones((2, 2), dtype=np.float32))
    packed_inputs = (first, second)
    with pytest.raises(TypeError):
        Jvp(net)(packed_inputs, (tangent, tangent))
139
140
def test_jvp_wrong_input_3_pynative():
    """A lone tangent Tensor for a two-input net must raise TypeError."""
    net = MultipleInputSingleOutputNet()
    first = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    second = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    tangent = Tensor(np.ones((2, 2), dtype=np.float32))
    with pytest.raises(TypeError):
        Jvp(net)(first, second, tangent)
148