# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test nn.Jvp in PyNative mode.

Positive tests unpack ``Jvp(net)(*inputs, v)`` into ``(primal, jvp)`` and check
both against analytically computed values. Negative tests check that malformed
argument structures raise ``TypeError``.
"""

import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn.grad import Jvp

context.set_context(mode=context.PYNATIVE_MODE)


class SingleInputSingleOutputNet(nn.Cell):
    """f(x) = x**3, so the JVP for tangent v is 3 * x**2 * v."""

    def construct(self, x):
        return x**3


class SingleInputMultipleOutputNet(nn.Cell):
    """f(x) = (x**3, 2*x), so the JVP is (3 * x**2 * v, 2 * v)."""

    def construct(self, x):
        return x**3, 2*x


class MultipleInputSingleOutputNet(nn.Cell):
    """f(x, y) = 2*x + 3*y, so the JVP for tangents (vx, vy) is 2*vx + 3*vy."""

    def construct(self, x, y):
        return 2*x + 3*y


class MultipleInputMultipleOutputNet(nn.Cell):
    """f(x, y) = (2*x, y**3), so the JVP is (2*vx, 3 * y**2 * vy)."""

    def construct(self, x, y):
        return 2*x, y**3


def _assert_close(actual, expected):
    """Assert that ``actual`` numerically matches ``expected``.

    ``actual`` is a Tensor, or a tuple/list of Tensors for multi-output nets.
    ``expected`` is a numpy array, or a tuple of numpy arrays with the same
    length as ``actual``.
    """
    if isinstance(expected, tuple):
        # Multi-output case: compare element-wise and make sure no output is
        # silently missing or extra.
        assert isinstance(actual, (tuple, list))
        assert len(actual) == len(expected)
        for actual_item, expected_item in zip(actual, expected):
            _assert_close(actual_item, expected_item)
    else:
        assert np.allclose(actual.asnumpy(), expected)


def test_jvp_single_input_single_output_default_v_pynative():
    """Single input, single output, all-ones tangent."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    primal, grad = Jvp(net)(x, v)
    _assert_close(primal, np.array([[1, 8], [27, 64]]))
    _assert_close(grad, np.array([[3, 12], [27, 48]]))


def test_jvp_single_input_single_output_custom_v_pynative():
    """Single input, single output, non-uniform tangent."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    primal, grad = Jvp(net)(x, v)
    _assert_close(primal, np.array([[1, 8], [27, 64]]))
    _assert_close(grad, np.array([[3, 24], [81, 192]]))


def test_jvp_single_input_multiple_outputs_default_v_pynative():
    """Single input, two outputs, all-ones tangent."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputMultipleOutputNet()
    primal, grad = Jvp(net)(x, v)
    _assert_close(primal, (np.array([[1, 8], [27, 64]]),
                           np.array([[2, 4], [6, 8]])))
    _assert_close(grad, (np.array([[3, 12], [27, 48]]),
                         np.array([[2, 2], [2, 2]])))


def test_jvp_single_input_multiple_outputs_custom_v_pynative():
    """Single input, two outputs, non-uniform tangent."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputMultipleOutputNet()
    primal, grad = Jvp(net)(x, v)
    _assert_close(primal, (np.array([[1, 8], [27, 64]]),
                           np.array([[2, 4], [6, 8]])))
    _assert_close(grad, (np.array([[3, 24], [81, 192]]),
                         np.array([[2, 4], [6, 8]])))


def test_jvp_multiple_inputs_multiple_outputs_default_v_pynative():
    """Two inputs, two outputs, all-ones tangent for both inputs."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputMultipleOutputNet()
    primal, grad = Jvp(net)(x, y, (v, v))
    _assert_close(primal, (np.array([[2, 4], [6, 8]]),
                           np.array([[1, 8], [27, 64]])))
    _assert_close(grad, (np.array([[2, 2], [2, 2]]),
                         np.array([[3, 12], [27, 48]])))


def test_jvp_multiple_inputs_multiple_outputs_custom_v_pynative():
    """Two inputs, two outputs, a different tangent per input."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = MultipleInputMultipleOutputNet()
    primal, grad = Jvp(net)(x, y, (v1, v2))
    _assert_close(primal, (np.array([[2, 4], [6, 8]]),
                           np.array([[1, 8], [27, 64]])))
    _assert_close(grad, (np.array([[2, 2], [2, 2]]),
                         np.array([[3, 24], [81, 192]])))


def test_jvp_multiple_inputs_single_output_default_v_pynative():
    """Two inputs, one output, all-ones tangent for both inputs."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    primal, grad = Jvp(net)(x, y, (v, v))
    _assert_close(primal, np.array([[5, 10], [15, 20]]))
    _assert_close(grad, np.array([[5, 5], [5, 5]]))


def test_jvp_multiple_inputs_single_output_custom_v_pynative():
    """Two inputs, one output, a different tangent per input."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    primal, grad = Jvp(net)(x, y, (v1, v2))
    _assert_close(primal, np.array([[5, 10], [15, 20]]))
    _assert_close(grad, np.array([[5, 8], [11, 14]]))


def test_jvp_wrong_input_v_pynative():
    """A single-input net given a 2-tuple of tangents must raise TypeError."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    with pytest.raises(TypeError):
        Jvp(net)(x, (v, v))


def test_jvp_wrong_input_v_2_pynative():
    """A single-input net given a 1-tuple of tangents must raise TypeError."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    with pytest.raises(TypeError):
        Jvp(net)(x, (v,))


def test_jvp_wrong_input_pynative():
    """A single-input net given two primal inputs must raise TypeError."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    with pytest.raises(TypeError):
        Jvp(net)(x, x, v)


def test_jvp_wrong_input_2_pynative():
    """Primal inputs packed into one tuple (instead of separate args) must raise TypeError."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    with pytest.raises(TypeError):
        Jvp(net)((x, y), (v, v))


def test_jvp_wrong_input_3_pynative():
    """A two-input net given a single (non-tuple) tangent must raise TypeError."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    with pytest.raises(TypeError):
        Jvp(net)(x, y, v)