# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph fallback: creating and using Tensor objects inside graph-mode code """
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor, jit, context
from mindspore import ops, tensor

import numpy as np
import pytest

# Every test in this module assumes graph mode unless it explicitly switches
# (and restores) the mode itself.
context.set_context(mode=context.GRAPH_MODE)


class ControlNet(nn.Cell):
    """Cell whose construct creates Tensors and branches on a Tensor comparison.

    construct builds two int32 scalar tensors (4 and 5) and returns their sum
    when that sum is greater than the input, otherwise their difference.
    """

    def inner_function_1(self, a, b):
        """Return a + b."""
        return a + b

    def inner_function_2(self, a, b):
        """Return a - b."""
        return a - b

    def construct(self, x):
        a = Tensor(np.array(4), mstype.int32)
        b = Tensor(np.array(5), mstype.int32)
        if a + b > x:
            return self.inner_function_1(a, b)
        return self.inner_function_2(a, b)


@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_control_sink_tensor():
    """
    Feature: Fallback feature: support define Tensor in Class construct.
    Description: Tensors created inside ControlNet.construct drive an if-branch;
        with x = 1 the condition 4 + 5 > 1 holds, so the sum branch is taken.
    Expectation: The network returns 9.
    """
    x = Tensor(np.array(1), mstype.int32)
    net = ControlNet()
    output = net(x)
    output_expect = Tensor(9, mstype.int32)
    assert output == output_expect


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_tensor_list():
    """
    Feature: Fallback feature
    Description: support Basic method of Tensor list: build a list of Tensors
        inside a jit function, iterate over it, append a derived Tensor and
        return the list.
    Expectation: No exception; the returned list holds 3 elements.
    """

    @jit
    def np_tensor_list():
        a = Tensor(np.array(4), mstype.int32)
        b = Tensor(np.array(5), mstype.int32)
        c = Tensor(np.array(6), mstype.int32)
        tensor_list = [a, b]
        for x in tensor_list:
            print(x)
        tensor_list.append(tensor_list[-1] + c)
        return tensor_list

    tensor_list = np_tensor_list()
    print("tensor_list:", tensor_list)
    assert len(tensor_list) == 3


@jit
def np_fallback_func_tensor_index(x):
    """Build float32 [2, 3, 4, 5] from numpy, double it, and return element x."""
    array_x = tuple([2, 3, 4, 5])
    np_x = np.array(array_x).astype(np.float32)
    me_x = Tensor(np_x)
    me_x = me_x + me_x
    return me_x[x]


@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_fallback_func_tensor_index():
    """
    Feature: Fallback feature: support Tensor index.
    Description: Index a Tensor built from a numpy array inside a jit function
        with a Tensor index; element 1 of [4, 6, 8, 10] is selected.
    Expectation: The result equals 6.0 (float32).
    """
    x = Tensor(1, mstype.int32)
    output = np_fallback_func_tensor_index(x)
    output_expect = Tensor(6, mstype.float32)
    assert output == output_expect


@pytest.mark.level2
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_tensor_compare_with_variable():
    """
    Feature: Fallback feature
    Description: Test ms.Tensor() in graph mode: a while-loop condition compares
        a variable against a Tensor constructed in the condition, and the body
        decrements by abs() of a constructed Tensor.
    Expectation: No exception; starting from 6 the loop counts down to 0.
    """

    @jit
    def foo(x):
        while x > Tensor([0]):
            x = x - abs(Tensor([-1]))
        return x

    res = foo(Tensor([6]))
    assert res == 0


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_np_tensor_add():
    """
    Feature: Fallback feature
    Description: support Tensor add: add a Tensor built from a numpy array of a
        Python int to the last element of a Tensor list inside a jit function.
    Expectation: No exception; the appended element equals 5 + 6 = 11.
    """

    @jit
    def np_tensor_add():
        a = Tensor(np.array(4))
        b = Tensor(np.array(5))
        tensor_list = [a, b]
        for x in tensor_list:
            print(x)
        # NOTE: `x` (the loop variable) is intentionally rebound to an int here;
        # left as-is so the graph-fallback path being exercised is unchanged.
        x = 6
        np_x = np.array(x)
        c = Tensor(np_x)
        d = tensor_list[-1] + c
        tensor_list.append(d)
        return tensor_list

    tensor_list = np_tensor_add()
    print("tensor_list:", tensor_list)
    assert tensor_list[-1] == 11


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_user_define_bprop_using_fallback():
    """
    Feature: Fallback feature
    Description: user define bprop support jit fallback: both construct and the
        user-defined bprop leave the graph via asnumpy() and return new Tensors.
        The gradient is computed in graph mode and again in PyNative mode.
    Expectation: No exception; graph-mode and PyNative gradients are allclose.
    """
    class TestBpropCell(nn.Cell):
        """Cell computing (x + x) * x through numpy, with a custom bprop."""

        def __init__(self):
            super().__init__()
            self.const_value = 1

        def construct(self, x):
            x = x * self.const_value
            x = x.asnumpy()
            x = (x + x) * x
            return tensor(x, mstype.float32)

        def bprop(self, x, out, dout):
            # The formula below uses only `dout` and is not the analytic gradient
            # of construct; the test only checks that both execution modes agree.
            x = dout.asnumpy()
            x = 2 * (x * x) * (np.log(x) + 1)
            return (tensor(x, mstype.float32),)

    class TestCell(nn.Cell):
        """Wraps TestBpropCell between scalar multiplies/adds."""

        def __init__(self):
            super().__init__()
            self.user_define_bprop = TestBpropCell()

        def construct(self, x):
            x = 2 * x
            x = self.user_define_bprop(x)
            x = x + 1
            x = 2 * x
            return x

    test_cell = TestCell()
    input_x = Tensor([1, 2, 3, 4], mstype.float32)
    graph_output = ops.grad(test_cell)(input_x)

    context.set_context(mode=context.PYNATIVE_MODE)
    try:
        pynative_out = ops.grad(test_cell)(input_x)
    finally:
        # Always restore graph mode: the execution mode is process-global, so a
        # failure here must not leave later tests running in PyNative mode.
        context.set_context(mode=context.GRAPH_MODE)

    assert np.allclose(graph_output.asnumpy(), pynative_out.asnumpy())