# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test lite inference python API.

Covers argument validation and basic behavior of ``mindspore_lite.Context``,
``Model``, ``Tensor`` and ``ModelGroup``. Tests that go through ``get_model()``
load ``mobilenetv2.ms`` by relative path, so that file must be present in the
working directory pytest is run from.
"""
import numpy as np
import pytest
import mindspore_lite as mslite


# ============================ Context ============================
def test_context_construct():
    context = mslite.Context()
    assert "target:" in str(context)


def test_context_target_type_error():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context()
        context.target = 1
    assert "target must be list" in str(raise_info.value)


def test_context_target_list_element_type_error():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context()
        context.target = [1]
    assert "target element must be str" in str(raise_info.value)


def test_context_target_list_element_value_error():
    with pytest.raises(ValueError) as raise_info:
        context = mslite.Context()
        context.target = ["1"]
    assert "target elements must be in" in str(raise_info.value)


def test_context_target():
    context = mslite.Context()
    context.target = ["cpu"]
    assert context.target == ["cpu"]
    context.target = ["gpu"]
    assert context.target == ["gpu"]
    context.target = ["ascend"]
    assert context.target == ["ascend"]
    # An empty target list is expected to fall back to ["cpu"].
    context.target = []
    assert context.target == ["cpu"]


def test_context_cpu_precision_mode_type_error():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context()
        context.cpu.precision_mode = 1
    assert "cpu_precision_mode must be str" in str(raise_info.value)


def test_context_cpu_precision_mode_value_error():
    with pytest.raises(ValueError) as raise_info:
        context = mslite.Context()
        context.cpu.precision_mode = "1"
    assert "cpu_precision_mode must be in" in str(raise_info.value)


def test_context_cpu_precision_mode():
    context = mslite.Context()
    context.cpu.precision_mode = "preferred_fp16"
    assert "precision_mode: preferred_fp16" in str(context.cpu)


def test_context_cpu_thread_num_type_error():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context()
        context.cpu.thread_num = "1"
    assert "cpu_thread_num must be int" in str(raise_info.value)


def test_context_cpu_thread_num_negative_value_error():
    with pytest.raises(ValueError) as raise_info:
        context = mslite.Context()
        context.cpu.thread_num = -1
    assert "cpu_thread_num must be a non-negative int" in str(raise_info.value)


def test_context_cpu_thread_num():
    context = mslite.Context()
    context.cpu.thread_num = 4
    assert "thread_num: 4" in str(context.cpu)


def test_context_cpu_inter_op_parallel_num_type_error():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context()
        context.cpu.inter_op_parallel_num = "1"
    assert "cpu_inter_op_parallel_num must be int" in str(raise_info.value)


def test_context_cpu_inter_op_parallel_num_negative_error():
    with pytest.raises(ValueError) as raise_info:
        context = mslite.Context()
        context.cpu.inter_op_parallel_num = -1
    assert "cpu_inter_op_parallel_num must be a non-negative int" in str(raise_info.value)


def test_context_cpu_inter_op_parallel_num():
    context = mslite.Context()
    context.cpu.inter_op_parallel_num = 1
    assert "inter_op_parallel_num: 1" in str(context.cpu)


def test_context_cpu_thread_affinity_mode_type_error():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context()
        context.cpu.thread_affinity_mode = "1"
    assert "cpu_thread_affinity_mode must be int" in str(raise_info.value)


def test_context_cpu_thread_affinity_mode():
    context = mslite.Context()
    context.cpu.thread_affinity_mode = 2
    assert "thread_affinity_mode: 2" in str(context.cpu)


def test_context_cpu_thread_affinity_core_list_type_error():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context()
        context.cpu.thread_affinity_core_list = 2
    assert "cpu_thread_affinity_core_list must be list" in str(raise_info.value)


def test_context_cpu_thread_affinity_core_list_element_type_error():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context()
        context.cpu.thread_affinity_core_list = ["1", "0"]
    assert "cpu_thread_affinity_core_list element must be int" in str(raise_info.value)


def test_context_cpu_thread_affinity_core_list():
    context = mslite.Context()
    context.cpu.thread_affinity_core_list = [2]
    assert "thread_affinity_core_list: [2]" in str(context.cpu)
    # Reassigning replaces the list, and the given order is preserved.
    context.cpu.thread_affinity_core_list = [1, 0]
    assert "thread_affinity_core_list: [1, 0]" in str(context.cpu)


def test_context_gpu_precision_mode_type_error():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context()
        context.gpu.precision_mode = 1
    assert "gpu_precision_mode must be str" in str(raise_info.value)


def test_context_gpu_precision_mode_value_error():
    with pytest.raises(ValueError) as raise_info:
        context = mslite.Context()
        context.gpu.precision_mode = "1"
    assert "gpu_precision_mode must be in" in str(raise_info.value)


def test_context_gpu_precision_mode():
    context = mslite.Context()
    context.gpu.precision_mode = "preferred_fp16"
    assert "precision_mode: preferred_fp16" in str(context.gpu)


def test_context_gpu_device_id_type_error():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context()
        context.gpu.device_id = "1"
    assert "gpu_device_id must be int" in str(raise_info.value)


def test_context_gpu_device_id_negative_error():
    with pytest.raises(ValueError) as raise_info:
        context = mslite.Context()
        context.gpu.device_id = -1
    assert "gpu_device_id must be a non-negative int" in str(raise_info.value)


def test_context_gpu_device_id():
    context = mslite.Context()
    context.gpu.device_id = 1
    assert "device_id: 1" in str(context.gpu)


def test_context_ascend_precision_mode_value_error():
    with pytest.raises(ValueError) as raise_info:
        context = mslite.Context()
        context.ascend.precision_mode = "1"
    assert "ascend_precision_mode must be in" in str(raise_info.value)


def test_context_ascend_precision_mode():
    context = mslite.Context()
    context.ascend.precision_mode = "enforce_fp32"
    assert "precision_mode: enforce_fp32" in str(context.ascend)


def test_context_ascend_device_id_type_error():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context()
        context.ascend.device_id = "1"
    assert "ascend_device_id must be int" in str(raise_info.value)


def test_context_ascend_device_id_negative_error():
    with pytest.raises(ValueError) as raise_info:
        context = mslite.Context()
        context.ascend.device_id = -1
    assert "ascend_device_id must be a non-negative int" in str(raise_info.value)


def test_context_ascend_device_id():
    context = mslite.Context()
    context.ascend.device_id = 1
    assert "device_id: 1" in str(context.ascend)


def test_context_ascend_provider_type_error():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context()
        context.ascend.provider = 1
    assert "ascend_provider must be str" in str(raise_info.value)


def test_context_ascend_provider():
    context = mslite.Context()
    context.ascend.provider = "ge"
    assert context.ascend.provider == "ge"
    assert "provider: ge" in str(context.ascend)


def test_context_ascend_rank_id_type_error():
    with pytest.raises(TypeError) as raise_info:
        context = mslite.Context()
        context.ascend.rank_id = "1"
    assert "ascend_rank_id must be int" in str(raise_info.value)


def test_context_ascend_rank_id_negative_error():
    with pytest.raises(ValueError) as raise_info:
        context = mslite.Context()
        context.ascend.rank_id = -1
    assert "ascend_rank_id must be a non-negative int" in str(raise_info.value)


def test_context_ascend_rank_id():
    context = mslite.Context()
    context.ascend.rank_id = 1
    assert "rank_id: 1" in str(context.ascend)


# ============================ Model ============================
def test_model_01():
    model = mslite.Model()
    assert "model_path:" in str(model)


def test_model_build_from_file_model_path_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = mslite.Model()
        model.build_from_file(model_path=1, model_type=mslite.ModelType.MINDIR_LITE)
    assert "model_path must be str" in str(raise_info.value)


def test_model_build_from_file_model_path_not_exist_error():
    with pytest.raises(RuntimeError) as raise_info:
        model = mslite.Model()
        model.build_from_file(model_path="test.ms", model_type=mslite.ModelType.MINDIR_LITE)
    assert "model_path does not exist" in str(raise_info.value)


def test_model_build_from_file_model_type_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = mslite.Model()
        model.build_from_file(model_path="test.ms", model_type="MINDIR_LITE")
    assert "model_type must be ModelType" in str(raise_info.value)


def test_model_build_from_file_context_type_error():
    with pytest.raises(TypeError) as raise_info:
        # Pass a per-device info object (Context().cpu) where a full Context is required.
        cpu_device_info = mslite.Context().cpu
        model = mslite.Model()
        model.build_from_file(model_path="test.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              context=cpu_device_info)
    assert "context must be Context" in str(raise_info.value)


def test_model_build_from_file_config_path_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = mslite.Model()
        model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              config_path=1)
    assert "config_path must be str" in str(raise_info.value)


def test_model_build_from_file_config_path_not_exist_error():
    with pytest.raises(RuntimeError) as raise_info:
        model = mslite.Model()
        model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              config_path="test.cfg")
    assert "config_path does not exist" in str(raise_info.value)


# The config_dict tests below probe each level of the expected
# {str section: {str key: str value}} structure, one wrong type at a time.
def test_model_build_from_file_config_dict_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = mslite.Model()
        model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              config_dict="test.cfg")
    assert "config_dict must be dict" in str(raise_info.value)


def test_model_build_from_file_config_dict_key_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = mslite.Model()
        dict_0 = {5: {"1": "2"}}
        model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              config_dict=dict_0)
    assert "config_dict_key must be str" in str(raise_info.value)


def test_model_build_from_file_config_dict_value_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = mslite.Model()
        dict_1 = {"5": "6"}
        model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              config_dict=dict_1)
    assert "config_dict_value must be dict" in str(raise_info.value)


def test_model_build_from_file_config_dict_value_key_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = mslite.Model()
        dict_2 = {"5": {3: "2"}}
        model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              config_dict=dict_2)
    assert "config_dict_value_key must be str" in str(raise_info.value)


def test_model_build_from_file_config_dict_value_value_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = mslite.Model()
        dict_3 = {"5": {"1": 2}}
        model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              config_dict=dict_3)
    assert "config_dict_value_value must be str" in str(raise_info.value)


def get_model():
    """Build and return a CPU ``Model`` (2 threads) from ``mobilenetv2.ms``.

    The model path is relative, so it is resolved against the current working
    directory.
    """
    context = mslite.Context()
    context.target = ["cpu"]
    context.cpu.thread_num = 2
    model = mslite.Model()
    model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE,
                          context=context)
    return model


def test_model_resize_inputs_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = get_model()
        inputs = model.get_inputs()
        model.resize(inputs[0], [[1, 112, 112, 3]])
    assert "inputs must be list" in str(raise_info.value)


def test_model_resize_inputs_elements_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = get_model()
        model.resize([1, 2], [[1, 112, 112, 3]])
    assert "inputs element must be Tensor" in str(raise_info.value)


def test_model_resize_dims_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = get_model()
        inputs = model.get_inputs()
        model.resize(inputs, "[[1, 112, 112, 3]]")
    assert "dims must be list" in str(raise_info.value)


def test_model_resize_dims_elements_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = get_model()
        inputs = model.get_inputs()
        model.resize(inputs, ["[1, 112, 112, 3]"])
    assert "dims element must be list" in str(raise_info.value)


def test_model_resize_dims_elements_elements_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = get_model()
        inputs = model.get_inputs()
        model.resize(inputs, [[1, "112", 112, 3]])
    assert "dims element's element must be int" in str(raise_info.value)


def test_model_resize_inputs_size_not_equal_dims_size_error():
    with pytest.raises(ValueError) as raise_info:
        model = get_model()
        inputs = model.get_inputs()
        # One input tensor but two shapes.
        model.resize(inputs, [[1, 112, 112, 3], [1, 112, 112, 3]])
    assert "inputs' size does not match dims' size" in str(raise_info.value)


def test_model_resize_01():
    """resize() must update the shape seen through previously fetched input tensors."""
    model = get_model()
    inputs = model.get_inputs()
    assert inputs[0].shape == [1, 224, 224, 3]
    model.resize(inputs, [[1, 112, 112, 3]])
    assert inputs[0].shape == [1, 112, 112, 3]


def test_model_predict_inputs_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = get_model()
        inputs = model.get_inputs()
        model.predict(inputs[0])
    assert "inputs must be list" in str(raise_info.value)


def test_model_predict_inputs_element_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = get_model()
        model.predict(["input"])
    assert "inputs element must be Tensor" in str(raise_info.value)


def test_model_get_model_info_type_error():
    with pytest.raises(TypeError) as raise_info:
        model = get_model()
        # Pass a non-str key so that the key type check is what raises. Calling
        # with no key at all cannot reach that check; at best it fails Python's
        # own argument binding, whose message does not say "key must be str".
        model.get_model_info(1)
    assert "key must be str" in str(raise_info.value)


def test_model_predict_runtime_error():
    with pytest.raises(RuntimeError) as raise_info:
        model = get_model()
        inputs = model.get_inputs()
        # No data has been set on the input tensors before predicting.
        model.predict(inputs)
    assert "predict failed" in str(raise_info.value)


def test_model_predict_01():
    """Smoke test: predict on the model's own input tensor filled from numpy."""
    model = get_model()
    inputs = model.get_inputs()
    in_data = np.arange(1 * 224 * 224 * 3, dtype=np.float32).reshape((1, 224, 224, 3))
    inputs[0].set_data_from_numpy(in_data)
    model.predict(inputs)


def test_model_predict_02():
    """Smoke test: predict on a user-created Tensor copying the model input's metadata."""
    model = get_model()
    inputs = model.get_inputs()
    input_tensor = mslite.Tensor()
    input_tensor.dtype = inputs[0].dtype
    input_tensor.shape = inputs[0].shape
    input_tensor.format = inputs[0].format
    input_tensor.name = inputs[0].name
    in_data = np.arange(1 * 224 * 224 * 3, dtype=np.float32).reshape((1, 224, 224, 3))
    input_tensor.set_data_from_numpy(in_data)
    model.predict([input_tensor])


# ============================ Tensor ============================
def test_tensor_type_error():
    """Constructing a Tensor from another Tensor is accepted (no TypeError raised)."""
    tensor1 = mslite.Tensor()
    tensor2 = mslite.Tensor(tensor=tensor1)  # now supported
    assert isinstance(tensor2, mslite.Tensor)


def test_tensor():
    tensor1 = mslite.Tensor()
    assert tensor1.name == ""


def test_tensor_name_type_error():
    with pytest.raises(TypeError) as raise_info:
        tensor = mslite.Tensor()
        tensor.name = 1
    assert "name must be str" in str(raise_info.value)


def test_tensor_name():
    tensor = mslite.Tensor()
    tensor.name = "tensor0"
    assert tensor.name == "tensor0"


def test_tensor_dtype_type_error():
    with pytest.raises(TypeError) as raise_info:
        tensor = mslite.Tensor()
        tensor.dtype = 1
    assert "dtype must be DataType" in str(raise_info.value)


def test_tensor_dtype():
    tensor = mslite.Tensor()
    tensor.dtype = mslite.DataType.INT32
    assert tensor.dtype == mslite.DataType.INT32


def test_tensor_shape_type_error():
    with pytest.raises(TypeError) as raise_info:
        tensor = mslite.Tensor()
        tensor.shape = 224
    assert "shape must be list" in str(raise_info.value)


def test_tensor_shape_element_type_error():
    with pytest.raises(TypeError) as raise_info:
        tensor = mslite.Tensor()
        tensor.shape = ["224", "224"]
    assert "shape element must be int" in str(raise_info.value)


def test_tensor_shape_get_element_num_get_data_size_01():
    tensor = mslite.Tensor()
    tensor.dtype = mslite.DataType.FLOAT32
    tensor.shape = [16, 16]
    assert tensor.shape == [16, 16]
    assert tensor.element_num == 256
    # 256 elements * 4 bytes per float32.
    assert tensor.data_size == 1024


def test_tensor_format_type_error():
    with pytest.raises(TypeError) as raise_info:
        tensor = mslite.Tensor()
        tensor.format = 1
    assert "format must be Format" in str(raise_info.value)


def test_tensor_format():
    tensor = mslite.Tensor()
    tensor.format = mslite.Format.NHWC4
    assert tensor.format == mslite.Format.NHWC4


def test_tensor_set_data_from_numpy_numpy_obj_type_error():
    with pytest.raises(TypeError) as raise_info:
        tensor = mslite.Tensor()
        tensor.set_data_from_numpy(1)
    assert "numpy_obj must be numpy.ndarray," in str(raise_info.value)


def test_tensor_set_data_from_numpy_data_type_not_equal_error():
    with pytest.raises(RuntimeError) as raise_info:
        tensor = mslite.Tensor()
        tensor.dtype = mslite.DataType.FLOAT32
        tensor.shape = [2, 3]
        # int32 data fed into a float32 tensor of matching shape.
        in_data = np.arange(2 * 3, dtype=np.int32).reshape((2, 3))
        tensor.set_data_from_numpy(in_data)
    assert "data type not equal" in str(raise_info.value)


def test_tensor_set_data_from_numpy_data_size_not_equal_error():
    with pytest.raises(RuntimeError) as raise_info:
        tensor = mslite.Tensor()
        tensor.dtype = mslite.DataType.FLOAT32
        # No shape is set on the tensor, so the 2x3 array does not fit it.
        in_data = np.arange(2 * 3, dtype=np.float32).reshape((2, 3))
        tensor.set_data_from_numpy(in_data)
    assert "data size not equal" in str(raise_info.value)


def test_tensor_set_data_from_numpy():
    tensor = mslite.Tensor()
    tensor.dtype = mslite.DataType.FLOAT32
    tensor.shape = [2, 3]
    in_data = np.arange(2 * 3, dtype=np.float32).reshape((2, 3))
    tensor.set_data_from_numpy(in_data)
    out_data = tensor.get_data_to_numpy()
    assert (out_data == in_data).all()


# ============================ ModelGroup ============================
def test_model_group_invalid_flags_error():
    with pytest.raises(RuntimeError) as raise_info:
        _ = mslite.ModelGroup(flags=1001)
    assert "Parameter flags should be ModelGroupFlag.SHARE_WORKSPACE or" in str(raise_info.value)


def test_model_group_add_model_share_workspace_add_model_obj_error():
    with pytest.raises(RuntimeError) as raise_info:
        model_group = mslite.ModelGroup(flags=mslite.ModelGroupFlag.SHARE_WORKSPACE)
        # Both models are freshly constructed; build_from_file was never called on them.
        model0 = mslite.Model()
        model1 = mslite.Model()
        model_group.add_model([model0, model1])
    assert "ModelGroup's add model failed." in str(raise_info.value)


def test_model_group_add_model_share_weight_add_model_path_error():
    with pytest.raises(RuntimeError) as raise_info:
        model_group = mslite.ModelGroup(flags=mslite.ModelGroupFlag.SHARE_WEIGHT)
        model_group.add_model(["model0_path", "model1_path"])
    assert "ModelGroup's add model failed." in str(raise_info.value)


def test_model_group_add_model_invalid_model_path_with_model_obj_error():
    with pytest.raises(TypeError) as raise_info:
        model_group = mslite.ModelGroup(flags=mslite.ModelGroupFlag.SHARE_WEIGHT)
        model1 = mslite.Model()
        # Mixing a path string and a Model object in one list is rejected.
        model_group.add_model(["model_path", model1])
    assert "models element must be all str or Model" in str(raise_info.value)


def test_model_group_add_model_invalid_model_obj_with_model_path_error():
    with pytest.raises(TypeError) as raise_info:
        model_group = mslite.ModelGroup(flags=mslite.ModelGroupFlag.SHARE_WEIGHT)
        model1 = mslite.Model()
        # Same mix as above, in the opposite order.
        model_group.add_model([model1, "model_path"])
    assert "models element must be all str or Model" in str(raise_info.value)


def test_model_group_add_model_invalid_model_obj_type_error():
    with pytest.raises(TypeError) as raise_info:
        model_group = mslite.ModelGroup(flags=mslite.ModelGroupFlag.SHARE_WEIGHT)
        model_group.add_model("model_path")
    assert "models must be list/tuple, but got" in str(raise_info.value)