1# Copyright 2020 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15""" 16@File : test_parse_method.py 17@Author: 18@Date : 2019-06-27 19@Desc : test parse the object's method 20""" 21import logging 22from dataclasses import dataclass 23 24import numpy as np 25import pytest 26 27import mindspore.nn as nn 28from mindspore import context 29from mindspore._extends.parse.standard_method import ms_len 30from mindspore.common.api import ms_function 31from mindspore.common.tensor import Tensor 32from mindspore.ops.composite import core 33from mindspore.ops.primitive import constexpr 34from mindspore.ops import functional as F 35from ..ut_filter import non_graph_engine 36 37 38def setup_module(module): 39 context.set_context(mode=context.PYNATIVE_MODE) 40 41 42log = logging.getLogger("test") 43log.setLevel(level=logging.ERROR) 44 45 46@ms_function 47def default_parameter_f(x, y=3): 48 """ default_parameter_f """ 49 z = x + y 50 return z 51 52 53# Test case: test parse fn that use default parameter 54def test_parse_defalut_parameter_case1(): 55 """ Test default parameter function call """ 56 log.debug("begin test_parse_defalut_parameter_case1") 57 ret = default_parameter_f(2) 58 log.debug("finished test_parse_defalut_parameter_case1, ret = %r", ret) 59 60 61def get_val_fn(x): 62 """ get_val_fn """ 63 ret = x + 3 64 return ret 65 66 67# Test case: test bool not 
@ms_function
def bool_exp(x, y):
    """ bool_exp """
    # Parsed by ms_function: `not` applied to a comparison expression.
    return not x > y


def test_bool_exp():
    """ Test that an ms_function using `not x > y` can be parsed and called. """
    bool_exp(1, 2)


# Test case: use variable positional parameters (*args) in an @ms_function
@ms_function
def var_parameter_f(x, *args):
    """ var_parameter_f """
    # Only the first three extra positional arguments are read.
    z = x + args[0] + args[1] + args[2]
    return z


def test_var_parameter_case1():
    """ Test calling an ms_function that takes *args with five positional values. """
    log.debug("start test_var_parameter_case1")
    var_parameter_f(1, 2, 3, 4, 5)
    log.debug("end test_var_parameter_case1")


class Net(nn.Cell):
    """ Net definition

    Cell whose compiled ``construct`` calls a regular (undecorated) method of
    the same Cell; ``construct(x)`` evaluates to ``x + value1``.
    """

    def __init__(self, value1):
        super(Net, self).__init__()
        # relu, softmax, axis and TC are set but not used by construct.
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(0)
        self.axis = 0
        self.TC = ClassTest("test_class", 1.2)
        self.value = value1

    @ms_function
    def construct(self, x):
        x = self.get_test_value(x)
        return x

    def get_test_value(self, x):
        # Invoked from the compiled construct; returns x + self.value.
        ret = x + self.value
        return ret


class ClassTest:
    """ ClassTest definition

    Plain Python (non-Cell) object holding a name and a value, used to test
    calling methods of ordinary objects from compiled code.
    """

    def __init__(self, name, value1):
        self.name = name
        self.value = value1

    def get_name(self):
        return self.name

    def get_value(self, inc):
        # Stored value plus `inc`.
        ret = self.value + inc
        return ret

    def __call__(self, *args, **kwargs):
        pass


# Test: call a method of the same Cell from its compiled construct
@non_graph_engine
def test_call_method_on_construct():
    """ Test that Net.construct can call Net.get_test_value and yields x + y. """
    log.debug("begin test_call_method_on_construct")

    x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
    y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
    z = np.array([[3, 5, 7], [2, 3, 5]]).astype(np.int32)

    net = Net(y)
    output = net.construct(x)
    result = output.asnumpy()
    print(result)
    assert np.all(result == z)

    log.debug("finished test_call_method_on_construct")


# Test: call a method of another (non-Cell) object stored on self from the
# compiled construct
class Net1(nn.Cell):
    """ Net1 definition

    Cell whose compiled ``construct`` calls ``get_value`` on a ClassTest
    instance stored as the attribute ``TC``.
    """

    def __init__(self, v1, v2):
        super(Net1, self).__init__()
        # relu, softmax and axis are set but not used by construct.
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(0)
        self.axis = 0
        self.TC = ClassTest("test_class", v1)
        self.value = v2

    @ms_function
    def construct(self, x):
        x = x + self.TC.get_value(self.value)
        return x


@non_graph_engine
def test_call_other_object_method():
    """ Test that calling a method of a non-Cell attribute in construct raises TypeError.

    Were it supported, the result would be x + y + y1 = [[8, 9, 12], [3, 4, 7]].
    """
    log.debug("begin test_call_other_object_method")

    x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
    y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
    y1 = Tensor(np.array([[5, 4, 5], [1, 1, 2]]).astype(np.int32))

    net = Net1(y, y1)
    # Only the call expected to raise belongs inside pytest.raises: any
    # statement after it in the block could never execute.
    with pytest.raises(TypeError):
        net.construct(x)

    log.debug("finished test_call_other_object_method")


# Test: call a method of a global object (not an attribute of self) from
# compiled construct code
value = Tensor(np.array([[3, 4, 5], [1, 1, 2]]).astype(np.int32))
TC = ClassTest("test_class", value)


class Net2(nn.Cell):
    """ Net2 definition

    Cell whose compiled methods reference the module-level ClassTest ``TC``.
    ``construct1`` is not exercised by any test in this file.
    """

    def __init__(self, value1):
        super(Net2, self).__init__()
        self.value = value1

    @ms_function
    def construct(self, x):
        x = x + TC.get_value(self.value)
        return x

    @ms_function
    def construct1(self, x):
        x = x + TC.value
        x = x + self.value
        return x


@non_graph_engine
def test_call_no_self_other_object_method():
    """ Test that calling a method of a global non-Cell object in construct raises TypeError.

    Were it supported, the result would be x + value + y = [[6, 9, 12], [3, 4, 7]].
    """
    log.debug("begin test_call_no_self_other_object_method")
    x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
    y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))

    net = Net2(y)
    # Only the call expected to raise belongs inside pytest.raises.
    with pytest.raises(TypeError):
        net.construct(x)

    log.debug("finished test_call_no_self_other_object_method")


def test_call_no_self_other_object_attr_value():
    """ test_call_no_self_other_object_attr_value """
    # do not support tensor as init input.
    # NOTE(review): this test is a no-op that reports as passed; consider
    # pytest.skip(...) so it shows up as skipped instead.
    return


# Test case: use * to unpack varargs inside an @ms_function
def vararg1(x, y):
    """ vararg1 """
    z = x + y
    return z


def varargs_main(fn):
    """ varargs_main

    Return an ms_function that forwards all of its positional arguments to
    ``fn``.
    """

    @ms_function
    def t1(*args):
        return fn(*args)

    return t1


def test_var_parameter_case3():
    """ Test forwarding *args from an ms_function closure to a plain function. """
    log.debug("start test_var_parameter_case3")
    ret = varargs_main(vararg1)(1, 2)
    log.debug("ret = %r", ret)
    log.debug("end test_var_parameter_case3")


# Test case: test the flag set by core(tg=True)
@core(tg=True)
def set_flag(x):
    """ set_flag """
    return x + 1


@ms_function
def set_test_flag_main(x, y):
    """ set_test_flag_main """
    z = set_flag(x)
    z = z + y
    return z


def test_set_flag():
    """ Test calling a core(tg=True)-decorated function from an ms_function. """
    log.debug("begin test_set_flag")
    ret = set_test_flag_main(2, 3)
    log.debug("finished test_set_flag, ret = %r", ret)


# Dataclass constructed inside an ms_function; max() returns the larger
# of a and b.
@dataclass
class Access:
    a: int
    b: int

    def max(self):
        if self.a > self.b:
            return self.a
        return self.b


@ms_function
def invoke_dataclass(x, y):
    """ invoke_dataclass """
    acs = Access(x, y)
    return acs.max()


def test_access():
    """ test_access """
    invoke_dataclass(1, 2)


# Like Access, but max() deliberately reads the undefined attribute `c`
# when a > b, so that an AttributeError can be checked.
@dataclass
class Access2:
    a: int
    b: int

    def max(self):
        if self.a > self.b:
            return self.c
        return self.b


@ms_function
def invoke_dataclass2(x, y):
    """ invoke_dataclass2 """
    acs = Access2(x, y)
    return acs.max()


def test_access_attr_error():
    """ Test that accessing an undefined dataclass attribute raises AttributeError. """
    # a=2 > b=1 takes the branch that reads the missing attribute `c`.
    with pytest.raises(AttributeError):
        invoke_dataclass2(2, 1)


def myfunc(x):
    """ myfunc """
    return x * x


@ms_function
def ms_infer_for():
    """ ms_infer_for """
    a = 0.0
    for x in [1.1, 2.3, 3.3]:
        a = a + x
    return a


def test_infer_for():
    """ test_infer_for """
    ms_infer_for()


@ms_function
def ms_infer_for_func(y):
    """ ms_infer_for_func """
    # Calls a plain Python function inside a for loop over a constant list.
    for x in [1.0, 2.0, 3.0]:
        y = myfunc(x) + y
    return y


def test_ms_infer_for_func():
    """ test_ms_infer_for_func """
    ms_infer_for_func(1.0)


@ms_function
def add(x, y):
    """ add """
    return x + y


def test_add():
    """ Test adding an int and a float inside an ms_function. """
    # Test functions should return None; the result is not checked.
    add(1, 2.0)


@ms_function
def add_list():
    """ add_list """
    a = [1, 2, 3]
    b = a[1] + a[2]
    return b


def test_list():
    """ Test indexing a constant list inside an ms_function. """
    add_list()


@ms_function
def compare_list_len():
    """ compare_list_len """
    # Returns the length of a constant list via ms_len.
    a = [1, 2, 3]
    return ms_len(a)


def test_list_len():
    """ test_list_len """
    compare_list_len()


@ms_function
def add_tuple():
    """ add_tuple """
    a = (1, 2, 3)
    b = a[1] + a[2]
    return b


def test_tuple():
    """ Test indexing a constant tuple inside an ms_function. """
    add_tuple()


def invoke_func(x):
    """ invoke_func """
    return x * x


@ms_function
def tuple_of_node(x, y):
    """ tuple_of_node """
    # Builds a tuple from two computed values and indexes into it.
    a = invoke_func(x)
    b = invoke_func(y)
    c = (a, b)
    d = c[1] * x
    return d


def test_tuple_node():
    """ Test building and indexing a tuple of computed values in an ms_function. """
    tuple_of_node(1, 2)


@ms_function
def range_spec(x, y):
    """ range_spec """
    # range(1, 10, 3) yields 1, 4, 7, so x is incremented three times.
    for _ in range(1, 10, 3):
        x = x + 1
    return x + y


def test_range():
    """ Test a for loop over range(start, stop, step) in an ms_function. """
    range_spec(10, 10)


def test_expr():
    """ test const expr """
    a = (1, 2)

    @constexpr
    def tuple_len(x):
        # Checks that the tuple passed in has length 2.
        assert len(x) == 2
    tuple_len(a)


def test_tuple_to_array():
    """ test range tuple to array """
    # F.tuple_to_array is called with a range object; the result is only printed.
    range_x = range(10)
    res = F.tuple_to_array(range_x)
    print(res)