# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for ``mindspore.nn.grad.Jvp`` in graph mode.

Each test calls ``Jvp(net)(*inputs, v)`` and checks both returned values:
the primal output ``net(*inputs)`` and the Jacobian-vector product J·v,
against analytically computed expectations. "default_v" tests use an
all-ones tangent; "custom_v" tests use a non-uniform tangent.
"""

import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn.grad import Jvp

context.set_context(mode=context.GRAPH_MODE)

# Same tolerances np.allclose uses by default; float32 graph kernels (e.g. pow)
# may not be bit-exact, so an exact comparison would be too strict.
_RTOL = 1e-5
_ATOL = 1e-8


def _assert_close(actual, expected):
    """Assert a MindSpore tensor equals a NumPy array within tolerance.

    Unlike a bare ``np.allclose`` check, this rejects a shape mismatch instead
    of silently broadcasting, and reports the mismatching elements on failure.
    """
    actual_np = actual.asnumpy()
    assert actual_np.shape == expected.shape, \
        f"shape mismatch: got {actual_np.shape}, expected {expected.shape}"
    np.testing.assert_allclose(actual_np, expected, rtol=_RTOL, atol=_ATOL)


def _assert_tuple_close(actual, expected):
    """Assert ``actual`` is a tuple of tensors matching ``expected`` element-wise."""
    assert isinstance(actual, tuple)
    assert len(actual) == len(expected)
    for actual_item, expected_item in zip(actual, expected):
        _assert_close(actual_item, expected_item)


class SingleInputSingleOutputNet(nn.Cell):
    """f(x) = x**3."""

    def construct(self, x):
        return x**3


class SingleInputMultipleOutputNet(nn.Cell):
    """f(x) = (x**3, 2*x)."""

    def construct(self, x):
        return x**3, 2*x


class MultipleInputSingleOutputNet(nn.Cell):
    """f(x, y) = 2*x + 3*y."""

    def construct(self, x, y):
        return 2*x + 3*y


class MultipleInputMultipleOutputNet(nn.Cell):
    """f(x, y) = (2*x, y**3)."""

    def construct(self, x, y):
        return 2*x, y**3


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_default_v_graph():
    """x**3 with an all-ones tangent: J·v = 3 * x**2."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    expect_primal = np.array([[1, 8], [27, 64]], dtype=np.float32)
    expect_grad = np.array([[3, 12], [27, 48]], dtype=np.float32)
    primal, grad = Jvp(net)(x, v)
    _assert_close(primal, expect_primal)
    _assert_close(grad, expect_grad)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_single_output_custom_v_graph():
    """x**3 with a non-uniform tangent: J·v = 3 * x**2 * v."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    expect_primal = np.array([[1, 8], [27, 64]], dtype=np.float32)
    expect_grad = np.array([[3, 24], [81, 192]], dtype=np.float32)
    primal, grad = Jvp(net)(x, v)
    _assert_close(primal, expect_primal)
    _assert_close(grad, expect_grad)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_default_v_graph():
    """(x**3, 2*x) with an all-ones tangent: J·v = (3 * x**2, 2)."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = SingleInputMultipleOutputNet()
    expect_primal = (np.array([[1, 8], [27, 64]], dtype=np.float32),
                     np.array([[2, 4], [6, 8]], dtype=np.float32))
    expect_grad = (np.array([[3, 12], [27, 48]], dtype=np.float32),
                   np.array([[2, 2], [2, 2]], dtype=np.float32))
    primal, grad = Jvp(net)(x, v)
    _assert_tuple_close(primal, expect_primal)
    _assert_tuple_close(grad, expect_grad)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_single_input_multiple_outputs_custom_v_graph():
    """(x**3, 2*x) with a non-uniform tangent: J·v = (3 * x**2 * v, 2 * v)."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputMultipleOutputNet()
    expect_primal = (np.array([[1, 8], [27, 64]], dtype=np.float32),
                     np.array([[2, 4], [6, 8]], dtype=np.float32))
    expect_grad = (np.array([[3, 24], [81, 192]], dtype=np.float32),
                   np.array([[2, 4], [6, 8]], dtype=np.float32))
    primal, grad = Jvp(net)(x, v)
    _assert_tuple_close(primal, expect_primal)
    _assert_tuple_close(grad, expect_grad)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_default_v_graph():
    """2*x + 3*y with all-ones tangents for both inputs: J·v = 2 + 3 = 5."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    expect_primal = np.array([[5, 10], [15, 20]], dtype=np.float32)
    expect_grad = np.array([[5, 5], [5, 5]], dtype=np.float32)
    # With several inputs the tangent is a tuple with one entry per input.
    primal, grad = Jvp(net)(x, y, (v, v))
    _assert_close(primal, expect_primal)
    _assert_close(grad, expect_grad)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_single_output_custom_v_graph():
    """2*x + 3*y with tangents (v1, v2): J·v = 2 * v1 + 3 * v2."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = MultipleInputSingleOutputNet()
    expect_primal = np.array([[5, 10], [15, 20]], dtype=np.float32)
    expect_grad = np.array([[5, 8], [11, 14]], dtype=np.float32)
    primal, grad = Jvp(net)(x, y, (v1, v2))
    _assert_close(primal, expect_primal)
    _assert_close(grad, expect_grad)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_default_v_graph():
    """(2*x, y**3) with all-ones tangents: J·v = (2, 3 * y**2)."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    net = MultipleInputMultipleOutputNet()
    expect_primal = (np.array([[2, 4], [6, 8]], dtype=np.float32),
                     np.array([[1, 8], [27, 64]], dtype=np.float32))
    expect_grad = (np.array([[2, 2], [2, 2]], dtype=np.float32),
                   np.array([[3, 12], [27, 48]], dtype=np.float32))
    primal, grad = Jvp(net)(x, y, (v, v))
    _assert_tuple_close(primal, expect_primal)
    _assert_tuple_close(grad, expect_grad)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph():
    """(2*x, y**3) with tangents (v1, v2): J·v = (2 * v1, 3 * y**2 * v2)."""
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    v1 = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    v2 = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = MultipleInputMultipleOutputNet()
    expect_primal = (np.array([[2, 4], [6, 8]], dtype=np.float32),
                     np.array([[1, 8], [27, 64]], dtype=np.float32))
    expect_grad = (np.array([[2, 2], [2, 2]], dtype=np.float32),
                   np.array([[3, 24], [81, 192]], dtype=np.float32))
    primal, grad = Jvp(net)(x, y, (v1, v2))
    _assert_tuple_close(primal, expect_primal)
    _assert_tuple_close(grad, expect_grad)