# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for ``mindspore.nn.grad.Jvp`` (Jacobian-vector product) in PyNative mode."""

import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn.grad import Jvp

context.set_context(mode=context.PYNATIVE_MODE)


def _f32(rows):
    """Return a float32 ``Tensor`` built from the nested list ``rows``."""
    return Tensor(np.array(rows, dtype=np.float32))


def _arr(rows):
    """Return a float32 ndarray built from ``rows`` (used for expected values)."""
    return np.array(rows, dtype=np.float32)


def _check(actual, expected):
    """Assert that Tensor ``actual`` is element-wise close to ndarray ``expected``."""
    assert np.allclose(actual.asnumpy(), expected)


def _check_pair(actual, expected):
    """Assert ``actual`` is a tuple whose Tensors match ``expected`` one by one."""
    assert isinstance(actual, tuple)
    assert len(actual) == len(expected)
    for got, want in zip(actual, expected):
        _check(got, want)


class SingleInputSingleOutputNet(nn.Cell):
    """f(x) = x ** 3."""

    def construct(self, x):
        return x ** 3


class SingleInputMultipleOutputNet(nn.Cell):
    """f(x) = (x ** 3, 2 * x)."""

    def construct(self, x):
        return x ** 3, 2 * x


class MultipleInputSingleOutputNet(nn.Cell):
    """f(x, y) = 2 * x + 3 * y."""

    def construct(self, x, y):
        return 2 * x + 3 * y


class MultipleInputMultipleOutputNet(nn.Cell):
    """f(x, y) = (2 * x, y ** 3)."""

    def construct(self, x, y):
        return 2 * x, y ** 3


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_default_v_pynative():
    """x ** 3 with an all-ones tangent: the JVP is 3 * x ** 2."""
    x = _f32([[1, 2], [3, 4]])
    v = _f32([[1, 1], [1, 1]])
    primal, grad = Jvp(SingleInputSingleOutputNet())(x, v)
    _check(primal, _arr([[1, 8], [27, 64]]))
    _check(grad, _arr([[3, 12], [27, 48]]))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_custom_v_pynative():
    """x ** 3 with tangent v == x: the JVP is 3 * x ** 2 * v."""
    x = _f32([[1, 2], [3, 4]])
    v = _f32([[1, 2], [3, 4]])
    primal, grad = Jvp(SingleInputSingleOutputNet())(x, v)
    _check(primal, _arr([[1, 8], [27, 64]]))
    _check(grad, _arr([[3, 24], [81, 192]]))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_default_v_pynative():
    """(x ** 3, 2 * x) with an all-ones tangent yields one JVP per output."""
    x = _f32([[1, 2], [3, 4]])
    v = _f32([[1, 1], [1, 1]])
    primal, grad = Jvp(SingleInputMultipleOutputNet())(x, v)
    _check_pair(primal, (_arr([[1, 8], [27, 64]]),
                         _arr([[2, 4], [6, 8]])))
    _check_pair(grad, (_arr([[3, 12], [27, 48]]),
                       _arr([[2, 2], [2, 2]])))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_custom_v_pynative():
    """(x ** 3, 2 * x) with tangent v == x: JVPs are 3 * x ** 2 * v and 2 * v."""
    x = _f32([[1, 2], [3, 4]])
    v = _f32([[1, 2], [3, 4]])
    primal, grad = Jvp(SingleInputMultipleOutputNet())(x, v)
    _check_pair(primal, (_arr([[1, 8], [27, 64]]),
                         _arr([[2, 4], [6, 8]])))
    _check_pair(grad, (_arr([[3, 24], [81, 192]]),
                       _arr([[2, 4], [6, 8]])))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_default_v_pynative():
    """(2 * x, y ** 3) with all-ones tangents for both inputs."""
    x = _f32([[1, 2], [3, 4]])
    y = _f32([[1, 2], [3, 4]])
    v = _f32([[1, 1], [1, 1]])
    primal, grad = Jvp(MultipleInputMultipleOutputNet())(x, y, (v, v))
    _check_pair(primal, (_arr([[2, 4], [6, 8]]),
                         _arr([[1, 8], [27, 64]])))
    _check_pair(grad, (_arr([[2, 2], [2, 2]]),
                       _arr([[3, 12], [27, 48]])))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_custom_v_pynative():
    """(2 * x, y ** 3) with tangents v1 (ones) for x and v2 (== y) for y."""
    x = _f32([[1, 2], [3, 4]])
    y = _f32([[1, 2], [3, 4]])
    v1 = _f32([[1, 1], [1, 1]])
    v2 = _f32([[1, 2], [3, 4]])
    primal, grad = Jvp(MultipleInputMultipleOutputNet())(x, y, (v1, v2))
    _check_pair(primal, (_arr([[2, 4], [6, 8]]),
                         _arr([[1, 8], [27, 64]])))
    _check_pair(grad, (_arr([[2, 2], [2, 2]]),
                       _arr([[3, 24], [81, 192]])))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_default_v_pynative():
    """2 * x + 3 * y with all-ones tangents: the JVP is 2 + 3 everywhere."""
    x = _f32([[1, 2], [3, 4]])
    y = _f32([[1, 2], [3, 4]])
    v = _f32([[1, 1], [1, 1]])
    primal, grad = Jvp(MultipleInputSingleOutputNet())(x, y, (v, v))
    _check(primal, _arr([[5, 10], [15, 20]]))
    _check(grad, _arr([[5, 5], [5, 5]]))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_custom_v_pynative():
    """2 * x + 3 * y with tangents v1 (ones) and v2: the JVP is 2 * v1 + 3 * v2."""
    x = _f32([[1, 2], [3, 4]])
    y = _f32([[1, 2], [3, 4]])
    v1 = _f32([[1, 1], [1, 1]])
    v2 = _f32([[1, 2], [3, 4]])
    primal, grad = Jvp(MultipleInputSingleOutputNet())(x, y, (v1, v2))
    _check(primal, _arr([[5, 10], [15, 20]]))
    _check(grad, _arr([[5, 8], [11, 14]]))