# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test jvp in graph mode"""

import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn.grad import Jvp

context.set_context(mode=context.GRAPH_MODE)

# Every test below is a smoke test. It wraps one of the toy cells in Jvp and
# calls it. The valid-call tests only require that no exception escapes; the
# ``wrong_input`` tests require a TypeError. Output values are never checked.
#
# As exercised here, the tangent is the last positional argument:
#   - single-input nets receive it as a bare Tensor;
#   - two-input nets receive it as a tuple with one Tensor per input.

# Point at which every net is evaluated (2x2, becomes float32).
_POINT = ((1, 2), (3, 4))
# All-ones tangent used by the ``default_v`` tests.
_ONES = ((1, 1), (1, 1))
# Non-uniform tangent used by the ``custom_v`` tests.
_CUSTOM = ((1, 2), (3, 4))


def _f32(values):
    """Wrap the nested sequence ``values`` in a float32 MindSpore Tensor."""
    return Tensor(np.array(values).astype(np.float32))


class SingleInputSingleOutputNet(nn.Cell):
    """Toy cell with one input and one output: ``x ** 3``."""

    def construct(self, x):
        return x ** 3


class SingleInputMultipleOutputNet(nn.Cell):
    """Toy cell with one input and two outputs: ``(x ** 3, 2 * x)``."""

    def construct(self, x):
        return x ** 3, 2 * x


class MultipleInputSingleOutputNet(nn.Cell):
    """Toy cell with two inputs and one output: ``2 * x + 3 * y``."""

    def construct(self, x, y):
        return 2 * x + 3 * y


class MultipleInputMultipleOutputNet(nn.Cell):
    """Toy cell with two inputs and two outputs: ``(2 * x, y ** 3)``."""

    def construct(self, x, y):
        return 2 * x, y ** 3


# --- valid call patterns: single input ---------------------------------------


def test_jvp_single_input_single_output_default_v_graph():
    """One input, one output, all-ones tangent."""
    point, tangent = _f32(_POINT), _f32(_ONES)
    jvp = Jvp(SingleInputSingleOutputNet())
    jvp(point, tangent)


def test_jvp_single_input_single_output_custom_v_graph():
    """One input, one output, non-uniform tangent."""
    point, tangent = _f32(_POINT), _f32(_CUSTOM)
    jvp = Jvp(SingleInputSingleOutputNet())
    jvp(point, tangent)


def test_jvp_single_input_multiple_outputs_default_v_graph():
    """One input, two outputs, all-ones tangent."""
    point, tangent = _f32(_POINT), _f32(_ONES)
    jvp = Jvp(SingleInputMultipleOutputNet())
    jvp(point, tangent)


def test_jvp_single_input_multiple_outputs_custom_v_graph():
    """One input, two outputs, non-uniform tangent."""
    point, tangent = _f32(_POINT), _f32(_CUSTOM)
    jvp = Jvp(SingleInputMultipleOutputNet())
    jvp(point, tangent)


# --- valid call patterns: two inputs -----------------------------------------


def test_jvp_multiple_inputs_single_output_default_v_graph():
    """Two inputs, one output, the same all-ones tangent for both inputs."""
    first, second = _f32(_POINT), _f32(_POINT)
    tangent = _f32(_ONES)
    jvp = Jvp(MultipleInputSingleOutputNet())
    jvp(first, second, (tangent, tangent))


def test_jvp_multiple_inputs_single_output_custom_v_graph():
    """Two inputs, one output, a different tangent per input."""
    first, second = _f32(_POINT), _f32(_POINT)
    tangents = (_f32(_ONES), _f32(_CUSTOM))
    jvp = Jvp(MultipleInputSingleOutputNet())
    jvp(first, second, tangents)


def test_jvp_multiple_inputs_multiple_outputs_default_v_graph():
    """Two inputs, two outputs, the same all-ones tangent for both inputs."""
    first, second = _f32(_POINT), _f32(_POINT)
    tangent = _f32(_ONES)
    jvp = Jvp(MultipleInputMultipleOutputNet())
    jvp(first, second, (tangent, tangent))


def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
    """Two inputs, two outputs, a different tangent per input."""
    first, second = _f32(_POINT), _f32(_POINT)
    tangents = (_f32(_ONES), _f32(_CUSTOM))
    jvp = Jvp(MultipleInputMultipleOutputNet())
    jvp(first, second, tangents)


# --- invalid call patterns: each must raise TypeError ------------------------


def test_jvp_wrong_input_v_graph():
    """A two-element tangent tuple for a single-input net is rejected."""
    point, tangent = _f32(_POINT), _f32(_ONES)
    net = SingleInputSingleOutputNet()
    with pytest.raises(TypeError):
        Jvp(net)(point, (tangent, tangent))


def test_jvp_wrong_input_v_2_graph():
    """A one-element tangent tuple for a single-input net is rejected."""
    point, tangent = _f32(_POINT), _f32(_ONES)
    net = SingleInputSingleOutputNet()
    with pytest.raises(TypeError):
        Jvp(net)(point, (tangent,))


def test_jvp_wrong_input_graph():
    """Passing two inputs to a single-input net is rejected."""
    point, tangent = _f32(_POINT), _f32(_ONES)
    net = SingleInputSingleOutputNet()
    with pytest.raises(TypeError):
        Jvp(net)(point, point, tangent)


def test_jvp_wrong_input_2_graph():
    """Packing both inputs of a two-input net into one tuple is rejected."""
    first, second = _f32(_POINT), _f32(_POINT)
    tangent = _f32(_ONES)
    net = MultipleInputSingleOutputNet()
    with pytest.raises(TypeError):
        Jvp(net)((first, second), (tangent, tangent))


def test_jvp_wrong_input_3_graph():
    """A bare (non-tuple) tangent for a two-input net is rejected."""
    first, second = _f32(_POINT), _f32(_POINT)
    tangent = _f32(_ONES)
    net = MultipleInputSingleOutputNet()
    with pytest.raises(TypeError):
        Jvp(net)(first, second, tangent)