• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""test jvp in graph mode"""
16
17import numpy as np
18import pytest
19import mindspore.nn as nn
20import mindspore.context as context
21from mindspore import Tensor
22from mindspore.nn.grad import Jvp
23
# All tests in this module exercise Jvp under MindSpore graph mode; this is set
# once at import time so every test below runs with the same execution mode.
context.set_context(mode=context.GRAPH_MODE)
25
26
class SingleInputSingleOutputNet(nn.Cell):
    """Test cell taking one tensor ``x`` and returning ``x ** 3``."""

    def construct(self, x):
        cubed = x ** 3
        return cubed
30
31
class SingleInputMultipleOutputNet(nn.Cell):
    """Test cell taking one tensor ``x`` and returning ``(x ** 3, 2 * x)``."""

    def construct(self, x):
        cubed = x ** 3
        doubled = 2 * x
        return cubed, doubled
35
36
class MultipleInputSingleOutputNet(nn.Cell):
    """Test cell taking ``x`` and ``y`` and returning ``2 * x + 3 * y``."""

    def construct(self, x, y):
        scaled_x = 2 * x
        scaled_y = 3 * y
        return scaled_x + scaled_y
40
41
class MultipleInputMultipleOutputNet(nn.Cell):
    """Test cell taking ``x`` and ``y`` and returning ``(2 * x, y ** 3)``."""

    def construct(self, x, y):
        doubled_x = 2 * x
        cubed_y = y ** 3
        return doubled_x, cubed_y
45
46
def test_jvp_single_input_single_output_default_v_graph():
    """Call Jvp on a one-input/one-output cell with an all-ones tangent.

    Only checks that the call completes; the returned values are not inspected.
    """
    net = SingleInputSingleOutputNet()
    primal = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    tangent = Tensor(np.ones((2, 2), dtype=np.float32))
    Jvp(net)(primal, tangent)
52
53
def test_jvp_single_input_single_output_custom_v_graph():
    """Call Jvp on a one-input/one-output cell with a non-uniform tangent.

    Only checks that the call completes; the returned values are not inspected.
    """
    net = SingleInputSingleOutputNet()
    primal = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    tangent = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    Jvp(net)(primal, tangent)
59
60
def test_jvp_single_input_multiple_outputs_default_v_graph():
    """Call Jvp on a one-input cell returning a tuple, with an all-ones tangent.

    Only checks that the call completes; the returned values are not inspected.
    """
    net = SingleInputMultipleOutputNet()
    primal = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    tangent = Tensor(np.ones((2, 2), dtype=np.float32))
    Jvp(net)(primal, tangent)
66
67
def test_jvp_single_input_multiple_outputs_custom_v_graph():
    """Call Jvp on a one-input cell returning a tuple, with a non-uniform tangent.

    Only checks that the call completes; the returned values are not inspected.
    """
    net = SingleInputMultipleOutputNet()
    primal = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    tangent = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    Jvp(net)(primal, tangent)
73
74
def test_jvp_multiple_inputs_single_output_default_v_graph():
    """Call Jvp on a two-input/one-output cell, one all-ones tangent per input.

    The same tangent tensor is reused for both inputs. Only checks that the
    call completes; the returned values are not inspected.
    """
    net = MultipleInputSingleOutputNet()
    lhs = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    rhs = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    ones = Tensor(np.ones((2, 2), dtype=np.float32))
    tangents = (ones, ones)
    Jvp(net)(lhs, rhs, tangents)
81
82
def test_jvp_multiple_inputs_single_output_custom_v_graph():
    """Call Jvp on a two-input/one-output cell with distinct per-input tangents.

    Only checks that the call completes; the returned values are not inspected.
    """
    net = MultipleInputSingleOutputNet()
    lhs = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    rhs = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    lhs_tangent = Tensor(np.ones((2, 2), dtype=np.float32))
    rhs_tangent = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    Jvp(net)(lhs, rhs, (lhs_tangent, rhs_tangent))
90
91
def test_jvp_multiple_inputs_multiple_outputs_default_v_graph():
    """Call Jvp on a two-input cell returning a tuple, all-ones tangent per input.

    The same tangent tensor is reused for both inputs. Only checks that the
    call completes; the returned values are not inspected.
    """
    net = MultipleInputMultipleOutputNet()
    lhs = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    rhs = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    ones = Tensor(np.ones((2, 2), dtype=np.float32))
    tangents = (ones, ones)
    Jvp(net)(lhs, rhs, tangents)
98
99
def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
    """Call Jvp on a two-input cell returning a tuple, with distinct tangents.

    Only checks that the call completes; the returned values are not inspected.
    """
    net = MultipleInputMultipleOutputNet()
    lhs = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    rhs = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    lhs_tangent = Tensor(np.ones((2, 2), dtype=np.float32))
    rhs_tangent = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    Jvp(net)(lhs, rhs, (lhs_tangent, rhs_tangent))
107
108
def test_jvp_wrong_input_v_graph():
    """A one-input cell given a tuple of two tangents must raise TypeError."""
    net = SingleInputSingleOutputNet()
    primal = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    tangent = Tensor(np.ones((2, 2), dtype=np.float32))
    too_many_tangents = (tangent, tangent)
    # Jvp is constructed inside the context so any TypeError it raises counts.
    with pytest.raises(TypeError):
        Jvp(net)(primal, too_many_tangents)
115
116
def test_jvp_wrong_input_v_2_graph():
    """A one-input cell given its tangent wrapped in a 1-tuple must raise TypeError."""
    net = SingleInputSingleOutputNet()
    primal = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    tangent = Tensor(np.ones((2, 2), dtype=np.float32))
    wrapped_tangent = (tangent,)
    # Jvp is constructed inside the context so any TypeError it raises counts.
    with pytest.raises(TypeError):
        Jvp(net)(primal, wrapped_tangent)
123
124
def test_jvp_wrong_input_graph():
    """A one-input cell given an extra positional input must raise TypeError."""
    net = SingleInputSingleOutputNet()
    primal = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    tangent = Tensor(np.ones((2, 2), dtype=np.float32))
    # The same primal tensor is passed twice, so the cell sees one input too many.
    with pytest.raises(TypeError):
        Jvp(net)(primal, primal, tangent)
131
132
def test_jvp_wrong_input_2_graph():
    """A two-input cell given its inputs packed into one tuple must raise TypeError."""
    net = MultipleInputSingleOutputNet()
    lhs = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    rhs = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    ones = Tensor(np.ones((2, 2), dtype=np.float32))
    packed_inputs = (lhs, rhs)
    tangents = (ones, ones)
    with pytest.raises(TypeError):
        Jvp(net)(packed_inputs, tangents)
140
141
def test_jvp_wrong_input_3_graph():
    """A two-input cell given a bare tensor instead of a tangent tuple must raise TypeError."""
    net = MultipleInputSingleOutputNet()
    lhs = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    rhs = Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
    single_tangent = Tensor(np.ones((2, 2), dtype=np.float32))
    with pytest.raises(TypeError):
        Jvp(net)(lhs, rhs, single_tangent)
149