# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for Jvp, run in both graph mode and PyNative mode."""

import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn.grad import Jvp


# Plain nested lists shared by the tests; Tensors are only built inside each
# test, after the execution mode has been set.
_BASE_VALUES = [[1, 2], [3, 4]]
_ONES = [[1, 1], [1, 1]]


def _float_tensor(values):
    """Build a float32 Tensor from a nested list of numbers."""
    return Tensor(np.array(values).astype(np.float32))


def _assert_close(actual, expected):
    """Assert that two Tensors are element-wise close (np.allclose defaults)."""
    assert np.allclose(actual.asnumpy(), expected.asnumpy())


def _assert_pair_close(actual, expected):
    """Assert that `actual` is a 2-tuple of Tensors close to the `expected` pair."""
    assert isinstance(actual, tuple)
    assert len(actual) == 2
    for got, want in zip(actual, expected):
        _assert_close(got, want)


class SingleInputSingleOutputNet(nn.Cell):
    """Cell computing x**3."""

    def construct(self, x):
        return x**3


class SingleInputMultipleOutputNet(nn.Cell):
    """Cell computing the pair (x**3, 2*x)."""

    def construct(self, x):
        return x**3, 2*x


class MultipleInputSingleOutputNet(nn.Cell):
    """Cell computing 2*x + 3*y."""

    def construct(self, x, y):
        return 2*x + 3*y


class MultipleInputMultipleOutputNet(nn.Cell):
    """Cell computing the pair (2*x, y**3)."""

    def construct(self, x, y):
        return 2*x, y**3


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_single_input_single_output_default_v_graph(mode):
    """
    Features: Class Jvp.
    Description: Check the forward-mode derivative of a single-input, single-output net with an all-ones v.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    inputs = _float_tensor(_BASE_VALUES)
    tangent = _float_tensor(_ONES)
    want_primal = _float_tensor([[1, 8], [27, 64]])
    want_grad = _float_tensor([[3, 12], [27, 48]])
    jvp_fn = Jvp(SingleInputSingleOutputNet())
    got_primal, got_grad = jvp_fn(inputs, tangent)
    _assert_close(got_primal, want_primal)
    _assert_close(got_grad, want_grad)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_single_input_single_output_custom_v_graph(mode):
    """
    Features: Class Jvp.
    Description: Check the forward-mode derivative of a single-input, single-output net with a non-uniform v.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    inputs = _float_tensor(_BASE_VALUES)
    tangent = _float_tensor(_BASE_VALUES)
    want_primal = _float_tensor([[1, 8], [27, 64]])
    want_grad = _float_tensor([[3, 24], [81, 192]])
    jvp_fn = Jvp(SingleInputSingleOutputNet())
    got_primal, got_grad = jvp_fn(inputs, tangent)
    _assert_close(got_primal, want_primal)
    _assert_close(got_grad, want_grad)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_single_input_multiple_outputs_default_v_graph(mode):
    """
    Features: Class Jvp.
    Description: Check the forward-mode derivative of a single-input, two-output net with an all-ones v.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    inputs = _float_tensor(_BASE_VALUES)
    tangent = _float_tensor(_ONES)
    want_primals = (
        _float_tensor([[1, 8], [27, 64]]),
        _float_tensor([[2, 4], [6, 8]]),
    )
    want_grads = (
        _float_tensor([[3, 12], [27, 48]]),
        _float_tensor([[2, 2], [2, 2]]),
    )
    jvp_fn = Jvp(SingleInputMultipleOutputNet())
    got_primals, got_grads = jvp_fn(inputs, tangent)
    _assert_pair_close(got_primals, want_primals)
    _assert_pair_close(got_grads, want_grads)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_single_input_multiple_outputs_custom_v_graph(mode):
    """
    Features: Class Jvp.
    Description: Check the forward-mode derivative of a single-input, two-output net with a non-uniform v.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    inputs = _float_tensor(_BASE_VALUES)
    tangent = _float_tensor(_BASE_VALUES)
    want_primals = (
        _float_tensor([[1, 8], [27, 64]]),
        _float_tensor([[2, 4], [6, 8]]),
    )
    want_grads = (
        _float_tensor([[3, 24], [81, 192]]),
        _float_tensor([[2, 4], [6, 8]]),
    )
    jvp_fn = Jvp(SingleInputMultipleOutputNet())
    got_primals, got_grads = jvp_fn(inputs, tangent)
    _assert_pair_close(got_primals, want_primals)
    _assert_pair_close(got_grads, want_grads)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_multiple_inputs_single_output_default_v_graph(mode):
    """
    Features: Class Jvp.
    Description: Check the forward-mode derivative of a two-input, single-output net with all-ones v for both inputs.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    first = _float_tensor(_BASE_VALUES)
    second = _float_tensor(_BASE_VALUES)
    tangent = _float_tensor(_ONES)
    want_primal = _float_tensor([[5, 10], [15, 20]])
    want_grad = _float_tensor([[5, 5], [5, 5]])
    jvp_fn = Jvp(MultipleInputSingleOutputNet())
    # The same tangent Tensor is passed for both inputs.
    got_primal, got_grad = jvp_fn(first, second, (tangent, tangent))
    _assert_close(got_primal, want_primal)
    _assert_close(got_grad, want_grad)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_multiple_inputs_single_output_custom_v_graph(mode):
    """
    Features: Class Jvp.
    Description: Check the forward-mode derivative of a two-input, single-output net with a distinct v per input.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    first = _float_tensor(_BASE_VALUES)
    second = _float_tensor(_BASE_VALUES)
    tangent_first = _float_tensor(_ONES)
    tangent_second = _float_tensor(_BASE_VALUES)
    want_primal = _float_tensor([[5, 10], [15, 20]])
    want_grad = _float_tensor([[5, 8], [11, 14]])
    jvp_fn = Jvp(MultipleInputSingleOutputNet())
    got_primal, got_grad = jvp_fn(first, second, (tangent_first, tangent_second))
    _assert_close(got_primal, want_primal)
    _assert_close(got_grad, want_grad)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_multiple_inputs_multiple_outputs_default_v_graph(mode):
    """
    Features: Class Jvp.
    Description: Check the forward-mode derivative of a two-input, two-output net with all-ones v for both inputs.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    first = _float_tensor(_BASE_VALUES)
    second = _float_tensor(_BASE_VALUES)
    tangent = _float_tensor(_ONES)
    want_primals = (
        _float_tensor([[2, 4], [6, 8]]),
        _float_tensor([[1, 8], [27, 64]]),
    )
    want_grads = (
        _float_tensor([[2, 2], [2, 2]]),
        _float_tensor([[3, 12], [27, 48]]),
    )
    jvp_fn = Jvp(MultipleInputMultipleOutputNet())
    # The same tangent Tensor is passed for both inputs.
    got_primals, got_grads = jvp_fn(first, second, (tangent, tangent))
    _assert_pair_close(got_primals, want_primals)
    _assert_pair_close(got_grads, want_grads)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jvp_multiple_inputs_multiple_outputs_custom_v_graph(mode):
    """
    Features: Class Jvp.
    Description: Check the forward-mode derivative of a two-input, two-output net with a distinct v per input.
    Expectation: No exception.
    """
    context.set_context(mode=mode)
    first = _float_tensor(_BASE_VALUES)
    second = _float_tensor(_BASE_VALUES)
    tangent_first = _float_tensor(_ONES)
    tangent_second = _float_tensor(_BASE_VALUES)
    want_primals = (
        _float_tensor([[2, 4], [6, 8]]),
        _float_tensor([[1, 8], [27, 64]]),
    )
    want_grads = (
        _float_tensor([[2, 2], [2, 2]]),
        _float_tensor([[3, 24], [81, 192]]),
    )
    jvp_fn = Jvp(MultipleInputMultipleOutputNet())
    got_primals, got_grads = jvp_fn(first, second, (tangent_first, tangent_second))
    _assert_pair_close(got_primals, want_primals)
    _assert_pair_close(got_grads, want_grads)