# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph fallback """
import pytest
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor, jit, context
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.probability import distribution
import mindspore.common.dtype as mstype
import mindspore.common._monad as monad
import mindspore.scipy.linalg as alg

context.set_context(mode=context.GRAPH_MODE)


# `add_func` is defined in current file.
# It is bound via F.partial inside `do_increment`, so it is compiled as part of
# that graph; keep its body free of docstrings/unsupported syntax.
def add_func(x, y):
    return x + y


# Partially applies `add_func` with 1 as its first argument and calls it on `i`.
@jit
def do_increment(i):
    add_1 = F.partial(add_func, 1)
    return add_1(i)


def test_increment():
    """
    Feature: JIT Fallback
    Description: Test F.partial on a function defined in the current file inside a jit function.
    Expectation: No exception, and do_increment(9) returns 10.
    """
    a = do_increment(9)
    assert a == 10


# Multiplies x and y, then attaches a dependency on the universal monad U.
@jit
def use_monad(x, y):
    res = P.Mul()(x, y)
    res = F.depend(res, monad.U)
    return res


def test_use_monad():
    """
    Feature: JIT Fallback
    Description: Test F.depend with monad.U in graph mode.
    Expectation: No exception.
    """
    x = Tensor(1.0, mstype.float32)
    y = Tensor(1.0, mstype.float32)
    print(use_monad(x, y))


# Builds and returns a tuple of Tensors inside the graph.
@jit
def use_tuple_of_tensor():
    me_x = (Tensor(1), Tensor(1))
    return me_x


def test_tuple_of_tensor():
    """
    Feature: JIT Fallback
    Description: Test tuple of tensor in graph mode.
    Expectation: No exception.
    """
    print(use_tuple_of_tensor())


# Builds and returns a list of Tensors inside the graph.
@jit
def use_list_of_tensor():
    me_x = [Tensor(1), Tensor(1)]
    return me_x


def test_list_of_tensor():
    """
    Feature: JIT Fallback
    Description: Test list of tensor in graph mode.
    Expectation: No exception.
    """
    print(use_list_of_tensor())


class Net(nn.Cell):
    # Holds a constant 3-element Tensor attribute whose length is taken in construct.
    def __init__(self):
        super(Net, self).__init__()
        self.x = Tensor([2, 3, 4])

    # Calls builtin len() on a Tensor attribute and loops over range() of it.
    def construct(self):
        x_len = len(self.x)
        for i in range(x_len):
            print(i)
        return x_len


def test_builtins_len():
    """
    Feature: JIT Fallback
    Description: Test builtin len() on a Tensor attribute and a range() loop in construct.
    Expectation: No exception.
    """
    net = Net()
    net()


# Creates a NumPy array inside the graph, converts it to a Tensor, and doubles it.
@jit
def np_fallback_func():
    array_x = tuple([2, 3, 4, 5])
    np_x = np.array(array_x).astype(np.float32)
    me_x = Tensor(np_x)
    me_x = me_x + me_x
    return me_x


def test_np_fallback_func():
    """
    Feature: JIT Fallback
    Description: Test NumPy array creation in graph and conversion to a Tensor.
    Expectation: No exception.
    """
    print(np_fallback_func())


# Test `return` interpret node.
@jit
def div_mod_func1():
    x = 8
    y = 3
    a = divmod(x, y)
    return Tensor(a)


def test_div_mod_func1():
    """
    Feature: JIT Fallback
    Description: Test divmod on constant scalars, returned through an interpret node.
    Expectation: No exception.
    """
    print(div_mod_func1())  # (2, 2)


# Test interpret node with parameters as input.
@jit
def div_mod_func2(x, y):
    a = divmod(x, y)
    return Tensor(a)


def test_div_mod_func2_scalar():
    """
    Feature: JIT Fallback
    Description: Test divmod in graph.
    Expectation: No exception.
    """
    print(div_mod_func2(8, 3))  # (2, 2)


@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_div_mod_func2_tensor():
    """
    Feature: JIT Fallback
    Description: Test divmod with Tensor input in graph. We'll support it in Tensor Input Fallback solution.
    Expectation: Not supported exception.
    """
    with pytest.raises(RuntimeError) as err:
        print(div_mod_func2(Tensor(8), Tensor(3)))
    # Must be outside the `with` block: the call above raises, so anything after
    # it inside the block would never execute.
    assert "Not support Tensor or variable type as input during running JIT Fallback, but got" in str(err.value)


# Selects an output based on isinstance checks in an if/elif/else chain.
@jit
def select_func(cond, x, y):
    if isinstance(cond, (tuple, list)):
        output = y
    elif isinstance(cond, Tensor):
        output = F.select(cond, x, y)
    else:
        output = x
    return output


def test_select_func():
    """
    Feature: JIT Fallback
    Description: Test isinstance in an if/elif/else chain in graph.
    Expectation: No exception.
    """
    cond = Tensor([True, False])
    x = Tensor([2, 3], mstype.float32)
    y = Tensor([1, 2], mstype.float32)
    print(select_func(cond, x, y))


# Same as select_func, but with two consecutive `if` statements instead of elif.
# NOTE(review): the value assigned by the first `if` is always overwritten by
# the following if/else; presumably intentional to exercise consecutive
# isinstance branches — confirm.
@jit
def select_func2(cond, x, y):
    if isinstance(cond, (tuple, list)):
        output = y
    if isinstance(cond, Tensor):
        output = F.select(cond, x, y)
    else:
        output = x
    return output


def test_select_func2():
    """
    Feature: JIT Fallback
    Description: Test isinstance in consecutive if statements in graph.
    Expectation: No exception.
    """
    cond = Tensor([True, False])
    x = Tensor([2, 3], mstype.float32)
    y = Tensor([1, 2], mstype.float32)
    print(select_func2(cond, x, y))


# Assigns `b` into rows 1..2 of the first axis of `a` and returns `a`.
@jit
def slice_func(a, b):
    a[1:3, ::] = b
    return a


def test_slice_func():
    """
    Feature: JIT Fallback
    Description: Test slice assignment on a Tensor in graph.
    Expectation: No exception.
    """
    a = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32)
    b = Tensor([1], dtype=mstype.float32)
    print(slice_func(a, b))


def test_context():
    """
    Feature: JIT Fallback
    Description: Test context in graph.
    Expectation: No exception.
    """
    class ContextNet(nn.Cell):
        # Captures the execution mode at construction time.
        def __init__(self):
            super(ContextNet, self).__init__()
            self.mode = context.get_context("mode")

        # Compares the captured mode against context.GRAPH_MODE inside the graph.
        def construct(self):
            out = 1
            if self.mode == context.GRAPH_MODE:
                out = 2
            return out

    net = ContextNet()
    out = net()
    print(out)


def test_scipy_module():
    """
    Feature: JIT Fallback
    Description: Test scipy module in graph.
    Expectation: No exception.
    """
    class Network(nn.Cell):
        # Calls mindspore.scipy.linalg.eigh on the input inside the graph.
        def construct(self, x):
            return alg.eigh(x)

    net = Network()
    x = Tensor([[2, 0, 0, 0], [0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]])
    out = net(x)
    print(out)


def test_probability_cauchy():
    """
    Feature: JIT Fallback
    Description: NumPy method is called in probability cauchy.
    Expectation: No exception.
    """
    class CauchyProb(nn.Cell):
        # Wraps a Cauchy distribution built from the given loc/scale parameters.
        def __init__(self, loc, scale, seed=10, dtype=mstype.float32, name='Cauchy'):
            super().__init__()
            self.b = distribution.Cauchy(loc, scale, seed, dtype, name)

        # Evaluates the distribution's density/CDF/survival methods on `value`.
        def construct(self, value, loc=None, scale=None):
            out1 = self.b.prob(value, loc, scale)
            out2 = self.b.log_prob(value, loc, scale)
            out3 = self.b.cdf(value, loc, scale)
            out4 = self.b.log_cdf(value, loc, scale)
            out5 = self.b.survival_function(value, loc, scale)
            out6 = self.b.log_survival(value, loc, scale)
            return out1, out2, out3, out4, out5, out6

    # Each array below holds 1024*512*7*7 float32 values (~100 MB each).
    # NOTE(review): the test's stated purpose does not appear to depend on this
    # size — consider shrinking it if memory/runtime becomes a problem.
    shape = (1024, 512, 7, 7)
    loc = np.random.randn(*shape).astype(np.float32)
    scale = np.random.uniform(0.0001, 100, size=shape).astype(np.float32)
    loc_a = np.random.randn(*shape).astype(np.float32)
    scale_a = np.random.uniform(0.0001, 100, size=shape).astype(np.float32)
    value = np.random.randn(*shape).astype(np.float32)

    net = CauchyProb(loc, scale)
    net(Tensor(value), Tensor(loc_a), Tensor(scale_a))