# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test function jacrev in graph and pynative modes."""
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore import jit
from mindspore.ops import jacrev


class SingleInputSingleOutputNet(nn.Cell):
    """Cell computing x ** 3 (one input, one output)."""

    def construct(self, x):
        return x ** 3


class SingleInputMultipleOutputsNet(nn.Cell):
    """Cell computing (x ** 3, 2 * x) (one input, two outputs)."""

    def construct(self, x):
        return x ** 3, 2 * x


class MultipleInputsSingleOutputNet(nn.Cell):
    """Cell computing x * y * z (three inputs, one output)."""

    def construct(self, x, y, z):
        return x * y * z


class MultipleInputsMultipleOutputsNet(nn.Cell):
    """Cell computing (x**2 + y**2 + z**2, x * y * z) (three inputs, two outputs)."""

    def construct(self, x, y, z):
        return x ** 2 + y ** 2 + z ** 2, x * y * z


# Two-output function used as the differentiation target of
# jac_wrap_with_jit_function; returns (x**2 + y**2 + z**2, x * y * z).
def function(x, y, z):
    return x ** 2 + y ** 2 + z ** 2, x * y * z


# NOTE(review): not referenced by any test in this file; kept in case it is
# used elsewhere.
def iteration_jac_function(x, y, z):
    return x ** 2 * y * z


# Runs jacrev(function, has_aux=True) inside a @jit-compiled function and
# returns its result, which the caller unpacks as (jacobian w.r.t. x, aux).
@jit
def jac_wrap_with_jit_function(x, y, z):
    output = jacrev(function, has_aux=True)(x, y, z)
    return output


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_single_input_single_output_cell_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev with single input and single output net in graph and pynative modes.
    Expectation: The computed Jacobian matches the expected values.
    """
    context.set_context(mode=mode)
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputSingleOutputNet()
    # d(x**3)/dx is elementwise, so the 4-D Jacobian only has the entries
    # 3 * x**2 on its "diagonal" positions.
    expect_jac = np.array([[[[3, 0], [0, 0]], [[0, 12], [0, 0]]],
                           [[[0, 0], [27, 0]], [[0, 0], [0, 48]]]]).astype(np.float32)
    jac = jacrev(net)(x)
    assert np.allclose(jac.asnumpy(), expect_jac)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_single_input_multiple_outputs_cell_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev with single input and multiple outputs net in graph and pynative modes.
    Expectation: One Jacobian per output, each matching the expected values.
    """
    context.set_context(mode=mode)
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    net = SingleInputMultipleOutputsNet()
    # Output 0 is x**3 -> 3 * x**2; output 1 is 2 * x -> constant 2.
    expect_jac_0 = np.array([[[[3, 0], [0, 0]], [[0, 12], [0, 0]]],
                             [[[0, 0], [27, 0]], [[0, 0], [0, 48]]]]).astype(np.float32)
    expect_jac_1 = np.array([[[[2, 0], [0, 0]], [[0, 2], [0, 0]]],
                             [[[0, 0], [2, 0]], [[0, 0], [0, 2]]]]).astype(np.float32)
    jac = jacrev(net)(x)
    assert np.allclose(jac[0].asnumpy(), expect_jac_0)
    assert np.allclose(jac[1].asnumpy(), expect_jac_1)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_multiple_inputs_single_output_cell_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev with multiple inputs and single output net in graph and pynative modes.
    Expectation: One Jacobian per selected input, each matching the expected values.
    """
    context.set_context(mode=mode)
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    net = MultipleInputsSingleOutputNet()
    # grad_position=(1, 2) selects y and z: d(xyz)/dy = x * z, d(xyz)/dz = x * y.
    expect_jac_0 = np.array([[[[0, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [15, 0]], [[0, 0], [0, -4]]]]).astype(np.float32)
    expect_jac_1 = np.array([[[[-2, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [-3, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
    jac = jacrev(net, grad_position=(1, 2))(x, y, z)
    assert isinstance(jac, tuple)
    assert len(jac) == 2
    assert np.allclose(jac[0].asnumpy(), expect_jac_0)
    assert np.allclose(jac[1].asnumpy(), expect_jac_1)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_multiple_inputs_multiple_outputs_cell_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev with multiple inputs and multiple outputs net in graph and pynative modes.
    Expectation: Jacobians indexed as jac[output][input] match the expected values.
    """
    context.set_context(mode=mode)
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    net = MultipleInputsMultipleOutputsNet()
    # Output 0 (x**2 + y**2 + z**2): d/dy = 2 * y, d/dz = 2 * z.
    expect_jac_0 = np.array([[[[-4, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [-2, 0]], [[0, 0], [0, 4]]]]).astype(np.float32)
    expect_jac_1 = np.array([[[[0, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [10, 0]], [[0, 0], [0, -2]]]]).astype(np.float32)
    # Output 1 (x * y * z): d/dy = x * z, d/dz = x * y.
    expect_jac_2 = np.array([[[[0, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [15, 0]], [[0, 0], [0, -4]]]]).astype(np.float32)
    expect_jac_3 = np.array([[[[-2, 0], [0, 0]], [[0, 6], [0, 0]]],
                             [[[0, 0], [-3, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
    jac = jacrev(net, grad_position=(1, 2))(x, y, z)
    assert isinstance(jac, tuple)
    assert len(jac) == 2
    assert np.allclose(jac[0][0].asnumpy(), expect_jac_0)
    assert np.allclose(jac[0][1].asnumpy(), expect_jac_1)
    assert np.allclose(jac[1][0].asnumpy(), expect_jac_2)
    assert np.allclose(jac[1][1].asnumpy(), expect_jac_3)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_wrap_with_jit_function_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev wrapped with @jit decorated function in graph and pynative modes.
    Expectation: The Jacobian and auxiliary output match the expected values.
    """
    context.set_context(mode=mode)
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
    z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
    # Default grad_position is x: d(x**2 + y**2 + z**2)/dx = 2 * x.
    expect_jac = np.array([[[[2, 0], [0, 0]], [[0, 4], [0, 0]]],
                           [[[0, 0], [6, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
    # With has_aux=True the second output x * y * z is returned as-is.
    expect_aux = np.array([[0, 18], [-15, -8]]).astype(np.float32)
    jac, aux = jac_wrap_with_jit_function(x, y, z)
    assert np.allclose(jac.asnumpy(), expect_jac)
    assert np.allclose(aux.asnumpy(), expect_aux)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_with_grad_position_twice_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev applied twice to the same Cell with different grad_position values
        in graph and pynative modes.
    Expectation: Both calls return Jacobians matching the expected values.
    """
    context.set_context(mode=mode)
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 3], [5, 7]]).astype(np.float32))
    z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    # d(xyz)/dx = y * z (= y here since z is all ones).
    expect_jac_0 = np.array([[[[1, 0], [0, 0]], [[0, 3], [0, 0]]],
                             [[[0, 0], [5, 0]], [[0, 0], [0, 7]]]]).astype(np.float32)
    # d(xyz)/dy = x * z (= x here since z is all ones).
    expect_jac_1 = np.array([[[[1, 0], [0, 0]], [[0, 2], [0, 0]]],
                             [[[0, 0], [3, 0]], [[0, 0], [0, 4]]]]).astype(np.float32)
    net = MultipleInputsSingleOutputNet()
    jac1 = jacrev(net, grad_position=0)(x, y, z)
    jac2 = jacrev(net, grad_position=(0, 1))(x, y, z)

    assert np.allclose(jac1.asnumpy(), expect_jac_0)
    # A tuple grad_position yields one Jacobian per selected input; the one
    # w.r.t. x must agree with the single-position result above.
    assert isinstance(jac2, tuple)
    assert len(jac2) == 2
    assert np.allclose(jac2[0].asnumpy(), expect_jac_0)
    assert np.allclose(jac2[1].asnumpy(), expect_jac_1)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_with_has_aux_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev with Cell setting has_aux in graph and pynative modes.
    Expectation: The Jacobian and auxiliary output match the expected values.
    """
    context.set_context(mode=mode)
    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    # Jacobian of the first output w.r.t. x: 2 * x.
    expect_jac = np.array([[[[2, 0], [0, 0]], [[0, 4], [0, 0]]],
                           [[[0, 0], [6, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
    # Aux is x * y * z, which equals x**2 here because y == x and z is all ones.
    expect_aux = np.array([[1, 4], [9, 16]]).astype(np.float32)
    net = MultipleInputsMultipleOutputsNet()
    jac, aux = jacrev(net, grad_position=0, has_aux=True)(x, y, z)
    assert np.allclose(jac.asnumpy(), expect_jac)
    assert np.allclose(aux.asnumpy(), expect_aux)


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_jac_with_function_has_aux_graph(mode):
    """
    Features: Function jacrev.
    Description: Test ops.jacrev with a variadic function setting has_aux in graph and pynative modes.
    Expectation: The Jacobian and auxiliary output match the expected values.
    """
    context.set_context(mode=mode)

    def fn(x, y, z):
        return x ** 2 + y ** 2 + z ** 2, x * y * z

    def fn2(*args):
        # Unpacks positional *args so jacrev is exercised on a variadic function.
        x = args[0]
        y = args[1]
        z = args[2]
        return fn(x, y, z)

    x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
    z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
    # Jacobian of the first output w.r.t. x: 2 * x.
    expect_jac = np.array([[[[2, 0], [0, 0]], [[0, 4], [0, 0]]],
                           [[[0, 0], [6, 0]], [[0, 0], [0, 8]]]]).astype(np.float32)
    # Aux is x * y * z, which equals x**2 here because y == x and z is all ones.
    expect_aux = np.array([[1, 4], [9, 16]]).astype(np.float32)
    jac, aux = jacrev(fn2, grad_position=0, has_aux=True)(x, y, z)
    assert np.allclose(jac.asnumpy(), expect_jac)
    assert np.allclose(aux.asnumpy(), expect_aux)