• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Test lite inference python API.
17"""
18import numpy as np
19import pytest
20import mindspore_lite as mslite
21
22
23# ============================ Context ============================
def test_context_construct():
    """The string form of a default Context includes a ``target:`` entry."""
    description = str(mslite.Context())
    assert "target:" in description
27
28
def test_context_target_type_error():
    """Assigning a non-list value to ``Context.target`` raises TypeError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        context.target = 1
    assert "target must be list" in str(raise_info.value)
34
35
def test_context_target_list_element_type_error():
    """A non-str element inside ``Context.target`` raises TypeError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        context.target = [1]
    assert "target element must be str" in str(raise_info.value)
41
42
def test_context_target_list_element_value_error():
    """An unrecognised target name inside ``Context.target`` raises ValueError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(ValueError) as raise_info:
        context.target = ["1"]
    assert "target elements must be in" in str(raise_info.value)
48
49
def test_context_target():
    """Each supported backend round-trips through ``Context.target``.

    An empty target list is read back as ``["cpu"]``.
    """
    ctx = mslite.Context()
    for backend in ("cpu", "gpu", "ascend"):
        ctx.target = [backend]
        assert ctx.target == [backend]
    # Clearing the list falls back to the CPU target.
    ctx.target = []
    assert ctx.target == ["cpu"]
60
61
def test_context_cpu_precision_mode_type_error():
    """A non-str ``cpu.precision_mode`` raises TypeError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        context.cpu.precision_mode = 1
    assert "cpu_precision_mode must be str" in str(raise_info.value)
67
68
def test_context_cpu_precision_mode_value_error():
    """An unrecognised ``cpu.precision_mode`` string raises ValueError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(ValueError) as raise_info:
        context.cpu.precision_mode = "1"
    assert "cpu_precision_mode must be in" in str(raise_info.value)
74
75
def test_context_cpu_precision_mode():
    """A valid ``cpu.precision_mode`` appears in the CPU device info's string form."""
    ctx = mslite.Context()
    ctx.cpu.precision_mode = "preferred_fp16"
    rendered = str(ctx.cpu)
    assert "precision_mode: preferred_fp16" in rendered
80
81
def test_context_cpu_thread_num_type_error():
    """A non-int ``cpu.thread_num`` raises TypeError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        context.cpu.thread_num = "1"
    assert "cpu_thread_num must be int" in str(raise_info.value)
87
88
def test_context_cpu_thread_num_negative_value_error():
    """A negative ``cpu.thread_num`` raises ValueError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(ValueError) as raise_info:
        context.cpu.thread_num = -1
    assert "cpu_thread_num must be a non-negative int" in str(raise_info.value)
94
95
def test_context_cpu_thread_num():
    """``cpu.thread_num`` appears in the CPU device info's string form."""
    ctx = mslite.Context()
    ctx.cpu.thread_num = 4
    rendered = str(ctx.cpu)
    assert "thread_num: 4" in rendered
100
101
def test_context_cpu_inter_op_parallel_num_type_error():
    """A non-int ``cpu.inter_op_parallel_num`` raises TypeError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        context.cpu.inter_op_parallel_num = "1"
    assert "cpu_inter_op_parallel_num must be int" in str(raise_info.value)
107
108
def test_context_cpu_inter_op_parallel_num_negative_error():
    """A negative ``cpu.inter_op_parallel_num`` raises ValueError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(ValueError) as raise_info:
        context.cpu.inter_op_parallel_num = -1
    assert "cpu_inter_op_parallel_num must be a non-negative int" in str(raise_info.value)
114
115
def test_context_cpu_inter_op_parallel_num():
    """``cpu.inter_op_parallel_num`` appears in the CPU device info's string form."""
    ctx = mslite.Context()
    ctx.cpu.inter_op_parallel_num = 1
    rendered = str(ctx.cpu)
    assert "inter_op_parallel_num: 1" in rendered
120
121
def test_context_cpu_thread_affinity_mode_type_error():
    """A non-int ``cpu.thread_affinity_mode`` raises TypeError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        context.cpu.thread_affinity_mode = "1"
    assert "cpu_thread_affinity_mode must be int" in str(raise_info.value)
127
128
def test_context_cpu_thread_affinity_mode():
    """``cpu.thread_affinity_mode`` appears in the CPU device info's string form."""
    ctx = mslite.Context()
    ctx.cpu.thread_affinity_mode = 2
    rendered = str(ctx.cpu)
    assert "thread_affinity_mode: 2" in rendered
133
134
def test_context_cpu_thread_affinity_core_list_type_error():
    """A non-list ``cpu.thread_affinity_core_list`` raises TypeError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        context.cpu.thread_affinity_core_list = 2
    assert "cpu_thread_affinity_core_list must be list" in str(raise_info.value)
140
141
def test_context_cpu_thread_affinity_core_list_element_type_error():
    """Non-int elements in ``cpu.thread_affinity_core_list`` raise TypeError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        context.cpu.thread_affinity_core_list = ["1", "0"]
    assert "cpu_thread_affinity_core_list element must be int" in str(raise_info.value)
147
148
def test_context_cpu_thread_affinity_core_list():
    """Single- and multi-core affinity lists appear verbatim in ``str(context.cpu)``."""
    ctx = mslite.Context()
    cases = (
        ([2], "thread_affinity_core_list: [2]"),
        ([1, 0], "thread_affinity_core_list: [1, 0]"),
    )
    for cores, expected in cases:
        ctx.cpu.thread_affinity_core_list = cores
        assert expected in str(ctx.cpu)
155
156
def test_context_gpu_precision_mode_type_error():
    """A non-str ``gpu.precision_mode`` raises TypeError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        context.gpu.precision_mode = 1
    assert "gpu_precision_mode must be str" in str(raise_info.value)
162
163
def test_context_gpu_precision_mode_value_error():
    """An unrecognised ``gpu.precision_mode`` string raises ValueError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(ValueError) as raise_info:
        context.gpu.precision_mode = "1"
    assert "gpu_precision_mode must be in" in str(raise_info.value)
169
170
def test_context_gpu_precision_mode():
    """A valid ``gpu.precision_mode`` appears in the GPU device info's string form."""
    ctx = mslite.Context()
    ctx.gpu.precision_mode = "preferred_fp16"
    rendered = str(ctx.gpu)
    assert "precision_mode: preferred_fp16" in rendered
175
176
def test_context_gpu_device_id_type_error():
    """A non-int ``gpu.device_id`` raises TypeError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        context.gpu.device_id = "1"
    assert "gpu_device_id must be int" in str(raise_info.value)
182
183
def test_context_gpu_device_id_negative_error():
    """A negative ``gpu.device_id`` raises ValueError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(ValueError) as raise_info:
        context.gpu.device_id = -1
    assert "gpu_device_id must be a non-negative int" in str(raise_info.value)
189
190
def test_context_gpu_device_id():
    """``gpu.device_id`` appears in the GPU device info's string form."""
    ctx = mslite.Context()
    ctx.gpu.device_id = 1
    rendered = str(ctx.gpu)
    assert "device_id: 1" in rendered
195
196
def test_context_ascend_precision_mode_value_error():
    """An unrecognised ``ascend.precision_mode`` string raises ValueError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(ValueError) as raise_info:
        context.ascend.precision_mode = "1"
    assert "ascend_precision_mode must be in" in str(raise_info.value)
202
203
def test_context_ascend_precision_mode():
    """A valid ``ascend.precision_mode`` appears in the Ascend device info's string form."""
    ctx = mslite.Context()
    ctx.ascend.precision_mode = "enforce_fp32"
    rendered = str(ctx.ascend)
    assert "precision_mode: enforce_fp32" in rendered
208
209
def test_context_ascend_device_id_type_error():
    """A non-int ``ascend.device_id`` raises TypeError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        context.ascend.device_id = "1"
    assert "ascend_device_id must be int" in str(raise_info.value)
215
216
def test_context_ascend_device_id_negative_error():
    """A negative ``ascend.device_id`` raises ValueError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(ValueError) as raise_info:
        context.ascend.device_id = -1
    assert "ascend_device_id must be a non-negative int" in str(raise_info.value)
222
223
def test_context_ascend_device_id():
    """``ascend.device_id`` appears in the Ascend device info's string form."""
    ctx = mslite.Context()
    ctx.ascend.device_id = 1
    rendered = str(ctx.ascend)
    assert "device_id: 1" in rendered
228
229
def test_context_ascend_provider_type_error():
    """A non-str ``ascend.provider`` raises TypeError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        context.ascend.provider = 1
    assert "ascend_provider must be str" in str(raise_info.value)
235
236
def test_context_ascend_provider():
    """``ascend.provider`` reads back as set and appears in ``str(context.ascend)``."""
    ctx = mslite.Context()
    ctx.ascend.provider = "ge"
    # Check both the property getter and the textual representation.
    assert ctx.ascend.provider == "ge"
    assert "provider: ge" in str(ctx.ascend)
242
243
def test_context_ascend_rank_id_type_error():
    """A non-int ``ascend.rank_id`` raises TypeError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        context.ascend.rank_id = "1"
    assert "ascend_rank_id must be int" in str(raise_info.value)
249
250
def test_context_ascend_rank_id_negative_error():
    """A negative ``ascend.rank_id`` raises ValueError."""
    context = mslite.Context()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(ValueError) as raise_info:
        context.ascend.rank_id = -1
    assert "ascend_rank_id must be a non-negative int" in str(raise_info.value)
256
257
def test_context_ascend_rank_id():
    """``ascend.rank_id`` appears in the Ascend device info's string form."""
    ctx = mslite.Context()
    ctx.ascend.rank_id = 1
    rendered = str(ctx.ascend)
    assert "rank_id: 1" in rendered
262
263
264# ============================ Model ============================
def test_model_01():
    """The string form of a fresh Model includes a ``model_path:`` entry."""
    description = str(mslite.Model())
    assert "model_path:" in description
268
269
def test_model_build_from_file_model_path_type_error():
    """A non-str ``model_path`` passed to ``build_from_file`` raises TypeError."""
    model = mslite.Model()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.build_from_file(model_path=1, model_type=mslite.ModelType.MINDIR_LITE)
    assert "model_path must be str" in str(raise_info.value)
275
276
def test_model_build_from_file_model_path_not_exist_error():
    """A missing model file makes ``build_from_file`` raise RuntimeError.

    Assumes ``test.ms`` does not exist in the working directory.
    """
    model = mslite.Model()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(RuntimeError) as raise_info:
        model.build_from_file(model_path="test.ms", model_type=mslite.ModelType.MINDIR_LITE)
    assert "model_path does not exist" in str(raise_info.value)
282
283
def test_model_build_from_file_model_type_type_error():
    """A str ``model_type`` (instead of ``mslite.ModelType``) raises TypeError."""
    model = mslite.Model()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.build_from_file(model_path="test.ms", model_type="MINDIR_LITE")
    assert "model_type must be ModelType" in str(raise_info.value)
289
290
def test_model_build_from_file_context_type_error():
    """Passing ``Context().cpu`` instead of a Context as ``context`` raises TypeError."""
    cpu_device_info = mslite.Context().cpu
    model = mslite.Model()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.build_from_file(model_path="test.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              context=cpu_device_info)
    assert "context must be Context" in str(raise_info.value)
297
298
def test_model_build_from_file_config_path_type_error():
    """A non-str ``config_path`` passed to ``build_from_file`` raises TypeError."""
    model = mslite.Model()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              config_path=1)
    assert "config_path must be str" in str(raise_info.value)
305
306
def test_model_build_from_file_config_path_not_exist_error():
    """A missing config file makes ``build_from_file`` raise RuntimeError.

    Assumes ``test.cfg`` does not exist in the working directory.
    """
    model = mslite.Model()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(RuntimeError) as raise_info:
        model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              config_path="test.cfg")
    assert "config_path does not exist" in str(raise_info.value)
313
314
def test_model_build_from_file_config_dict_type_error():
    """A non-dict ``config_dict`` passed to ``build_from_file`` raises TypeError."""
    model = mslite.Model()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              config_dict="test.cfg")
    assert "config_dict must be dict" in str(raise_info.value)
321
322
def test_model_build_from_file_config_dict_key_type_error():
    """A non-str top-level key in ``config_dict`` raises TypeError."""
    model = mslite.Model()
    dict_0 = {5: {"1": "2"}}
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              config_dict=dict_0)
    assert "config_dict_key must be str" in str(raise_info.value)
330
331
def test_model_build_from_file_config_dict_value_type_error():
    """A non-dict section value in ``config_dict`` raises TypeError."""
    model = mslite.Model()
    dict_1 = {"5": "6"}
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              config_dict=dict_1)
    assert "config_dict_value must be dict" in str(raise_info.value)
339
340
def test_model_build_from_file_config_dict_value_key_type_error():
    """A non-str key inside a ``config_dict`` section raises TypeError."""
    model = mslite.Model()
    dict_2 = {"5": {3: "2"}}
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              config_dict=dict_2)
    assert "config_dict_value_key must be str" in str(raise_info.value)
348
349
def test_model_build_from_file_config_dict_value_value_type_error():
    """A non-str value inside a ``config_dict`` section raises TypeError."""
    model = mslite.Model()
    dict_3 = {"5": {"1": 2}}
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.build_from_file(model_path="mobilenetv2.ms", model_type=mslite.ModelType.MINDIR_LITE,
                              config_dict=dict_3)
    assert "config_dict_value_value must be str" in str(raise_info.value)
357
358
def get_model():
    """Build and return a MINDIR_LITE Model from ``mobilenetv2.ms`` on CPU.

    The context targets CPU only with two threads. The model is loaded through
    the relative path ``mobilenetv2.ms``, presumably resolved against the
    working directory the tests run from.
    """
    cpu_context = mslite.Context()
    cpu_context.target = ["cpu"]
    cpu_context.cpu.thread_num = 2
    mobilenet = mslite.Model()
    mobilenet.build_from_file(model_path="mobilenetv2.ms",
                              model_type=mslite.ModelType.MINDIR_LITE,
                              context=cpu_context)
    return mobilenet
366
367
def test_model_resize_inputs_type_error():
    """Passing a single Tensor instead of a list as ``inputs`` to ``resize`` raises TypeError."""
    model = get_model()
    inputs = model.get_inputs()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.resize(inputs[0], [[1, 112, 112, 3]])
    assert "inputs must be list" in str(raise_info.value)
374
375
def test_model_resize_inputs_elements_type_error():
    """Non-Tensor elements in ``resize``'s ``inputs`` list raise TypeError."""
    model = get_model()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.resize([1, 2], [[1, 112, 112, 3]])
    assert "inputs element must be Tensor" in str(raise_info.value)
381
382
def test_model_resize_dims_type_error():
    """A non-list ``dims`` argument to ``resize`` raises TypeError."""
    model = get_model()
    inputs = model.get_inputs()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.resize(inputs, "[[1, 112, 112, 3]]")
    assert "dims must be list" in str(raise_info.value)
389
390
def test_model_resize_dims_elements_type_error():
    """A non-list element in ``resize``'s ``dims`` raises TypeError."""
    model = get_model()
    inputs = model.get_inputs()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.resize(inputs, ["[1, 112, 112, 3]"])
    assert "dims element must be list" in str(raise_info.value)
397
398
def test_model_resize_dims_elements_elements_type_error():
    """A non-int dimension inside a ``dims`` entry raises TypeError."""
    model = get_model()
    inputs = model.get_inputs()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.resize(inputs, [[1, "112", 112, 3]])
    assert "dims element's element must be int" in str(raise_info.value)
405
406
def test_model_resize_inputs_size_not_equal_dims_size_error():
    """More ``dims`` entries than model inputs makes ``resize`` raise ValueError."""
    model = get_model()
    inputs = model.get_inputs()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(ValueError) as raise_info:
        model.resize(inputs, [[1, 112, 112, 3], [1, 112, 112, 3]])
    assert "inputs' size does not match dims' size" in str(raise_info.value)
413
414
def test_model_resize_01():
    """``resize`` changes the shape reported by the model's first input tensor."""
    model = get_model()
    model_inputs = model.get_inputs()
    first_input = model_inputs[0]
    assert first_input.shape == [1, 224, 224, 3]
    model.resize(model_inputs, [[1, 112, 112, 3]])
    # The same tensor object reflects the new shape after resizing.
    assert first_input.shape == [1, 112, 112, 3]
421
422
def test_model_predict_inputs_type_error():
    """Passing a single Tensor instead of a list to ``predict`` raises TypeError."""
    model = get_model()
    inputs = model.get_inputs()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.predict(inputs[0])
    assert "inputs must be list" in str(raise_info.value)
429
430
def test_model_predict_inputs_element_type_error():
    """Non-Tensor elements in ``predict``'s ``inputs`` list raise TypeError."""
    model = get_model()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model.predict(["input"])
    assert "inputs element must be Tensor" in str(raise_info.value)
436
def test_model_get_model_info_type_error():
    """A non-str ``key`` passed to ``get_model_info`` raises TypeError."""
    model = get_model()
    # Pass a non-str key explicitly. A bare ``get_model_info()`` call would, if
    # ``key`` is a required parameter, raise Python's own missing-argument
    # TypeError, which never contains the validation message asserted below.
    with pytest.raises(TypeError) as raise_info:
        model.get_model_info(1)
    assert "key must be str" in str(raise_info.value)
442
def test_model_predict_runtime_error():
    """``predict`` raises RuntimeError on inputs whose data was never set.

    The failure is presumably caused by the missing input data; only the
    "predict failed" message is checked.
    """
    model = get_model()
    inputs = model.get_inputs()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(RuntimeError) as raise_info:
        model.predict(inputs)
    assert "predict failed" in str(raise_info.value)
449
450
def test_model_predict_01():
    """``predict`` succeeds on the model's own inputs once they hold float32 data."""
    model = get_model()
    inputs = model.get_inputs()
    # Size the data from the input tensor itself so it always matches the model.
    shape = inputs[0].shape
    in_data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    inputs[0].set_data_from_numpy(in_data)
    outputs = model.predict(inputs)
    # Not raising is not enough: the model must also produce output tensors.
    assert outputs, "predict returned no output tensors"
457
458
def test_model_predict_02():
    """``predict`` accepts a user-built Tensor that mirrors the model input's metadata."""
    model = get_model()
    inputs = model.get_inputs()
    template = inputs[0]
    input_tensor = mslite.Tensor()
    # Copy dtype/shape/format/name from the model's first input tensor.
    input_tensor.dtype = template.dtype
    input_tensor.shape = template.shape
    input_tensor.format = template.format
    input_tensor.name = template.name
    # Size the data from the mirrored shape so it always matches the tensor.
    shape = template.shape
    in_data = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
    input_tensor.set_data_from_numpy(in_data)
    outputs = model.predict([input_tensor])
    # Not raising is not enough: the model must also produce output tensors.
    assert outputs, "predict returned no output tensors"
470
471
472# ============================ Tensor ============================
def test_tensor_type_error():
    """Constructing a Tensor from another Tensor via ``tensor=`` succeeds.

    The test name is historical: this construction used to be rejected with
    a TypeError but is now supported, so the test checks that it works.
    """
    tensor1 = mslite.Tensor()
    tensor2 = mslite.Tensor(tensor=tensor1)
    assert isinstance(tensor2, mslite.Tensor)
476
477
def test_tensor():
    """A freshly constructed Tensor has an empty name."""
    assert mslite.Tensor().name == ""
481
482
def test_tensor_name_type_error():
    """Assigning a non-str ``Tensor.name`` raises TypeError."""
    tensor = mslite.Tensor()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        tensor.name = 1
    assert "name must be str" in str(raise_info.value)
488
489
def test_tensor_name():
    """``Tensor.name`` reads back the value that was assigned."""
    t = mslite.Tensor()
    t.name = "tensor0"
    assert t.name == "tensor0"
494
495
def test_tensor_dtype_type_error():
    """Assigning a non-DataType ``Tensor.dtype`` raises TypeError."""
    tensor = mslite.Tensor()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        tensor.dtype = 1
    assert "dtype must be DataType" in str(raise_info.value)
501
502
def test_tensor_dtype():
    """``Tensor.dtype`` reads back the DataType that was assigned."""
    t = mslite.Tensor()
    t.dtype = mslite.DataType.INT32
    assert t.dtype == mslite.DataType.INT32
507
508
def test_tensor_shape_type_error():
    """Assigning a non-list ``Tensor.shape`` raises TypeError."""
    tensor = mslite.Tensor()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        tensor.shape = 224
    assert "shape must be list" in str(raise_info.value)
514
515
def test_tensor_shape_element_type_error():
    """Non-int entries in ``Tensor.shape`` raise TypeError."""
    tensor = mslite.Tensor()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        tensor.shape = ["224", "224"]
    assert "shape element must be int" in str(raise_info.value)
521
522
def test_tensor_shape_get_element_num_get_data_size_01():
    """A 16x16 FLOAT32 tensor reports 256 elements and 1024 bytes of data."""
    t = mslite.Tensor()
    t.dtype = mslite.DataType.FLOAT32
    t.shape = [16, 16]
    assert t.shape == [16, 16]
    expected_elements = 16 * 16
    assert t.element_num == expected_elements
    # FLOAT32 occupies four bytes per element.
    assert t.data_size == expected_elements * 4
530
531
def test_tensor_format_type_error():
    """Assigning a non-Format ``Tensor.format`` raises TypeError."""
    tensor = mslite.Tensor()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        tensor.format = 1
    assert "format must be Format" in str(raise_info.value)
537
538
def test_tensor_format():
    """``Tensor.format`` reads back the Format that was assigned."""
    t = mslite.Tensor()
    t.format = mslite.Format.NHWC4
    assert t.format == mslite.Format.NHWC4
543
544
def test_tensor_set_data_from_numpy_numpy_obj_type_error():
    """Passing a non-ndarray to ``set_data_from_numpy`` raises TypeError."""
    tensor = mslite.Tensor()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        tensor.set_data_from_numpy(1)
    assert "numpy_obj must be numpy.ndarray," in str(raise_info.value)
550
551
def test_tensor_set_data_from_numpy_data_type_not_equal_error():
    """int32 data set into a FLOAT32 tensor of the same shape raises RuntimeError."""
    tensor = mslite.Tensor()
    tensor.dtype = mslite.DataType.FLOAT32
    tensor.shape = [2, 3]
    in_data = np.arange(2 * 3, dtype=np.int32).reshape((2, 3))
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(RuntimeError) as raise_info:
        tensor.set_data_from_numpy(in_data)
    assert "data type not equal" in str(raise_info.value)
560
561
def test_tensor_set_data_from_numpy_data_size_not_equal_error():
    """Data whose size differs from the tensor's raises RuntimeError.

    The tensor's shape is deliberately left unset, so the 2x3 float32 array
    cannot match its data size.
    """
    tensor = mslite.Tensor()
    tensor.dtype = mslite.DataType.FLOAT32
    in_data = np.arange(2 * 3, dtype=np.float32).reshape((2, 3))
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(RuntimeError) as raise_info:
        tensor.set_data_from_numpy(in_data)
    assert "data size not equal" in str(raise_info.value)
569
570
def test_tensor_set_data_from_numpy():
    """float32 data round-trips through ``set_data_from_numpy`` / ``get_data_to_numpy``."""
    tensor = mslite.Tensor()
    tensor.dtype = mslite.DataType.FLOAT32
    tensor.shape = [2, 3]
    in_data = np.arange(2 * 3, dtype=np.float32).reshape((2, 3))
    tensor.set_data_from_numpy(in_data)
    out_data = tensor.get_data_to_numpy()
    # Unlike ``(out == in).all()``, this also rejects a shape mismatch (no broadcasting)
    # and reports the differing elements on failure.
    np.testing.assert_array_equal(out_data, in_data)
579
580
# ============================ ModelGroup ============================
def test_model_group_invalid_flags_error():
    """An unsupported ``flags`` value for ModelGroup raises RuntimeError."""
    with pytest.raises(RuntimeError) as raise_info:
        mslite.ModelGroup(flags=1001)
    message = str(raise_info.value)
    assert "Parameter flags should be ModelGroupFlag.SHARE_WORKSPACE or" in message
585
586
def test_model_group_add_model_share_workspace_add_model_obj_error():
    """Adding never-built Model objects to a SHARE_WORKSPACE ModelGroup raises RuntimeError."""
    model_group = mslite.ModelGroup(flags=mslite.ModelGroupFlag.SHARE_WORKSPACE)
    model0 = mslite.Model()
    model1 = mslite.Model()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(RuntimeError) as raise_info:
        model_group.add_model([model0, model1])
    assert "ModelGroup's add model failed." in str(raise_info.value)
594
595
def test_model_group_add_model_share_weight_add_model_path_error():
    """Adding placeholder model paths to a SHARE_WEIGHT ModelGroup raises RuntimeError.

    The paths are assumed not to exist in the working directory.
    """
    model_group = mslite.ModelGroup(flags=mslite.ModelGroupFlag.SHARE_WEIGHT)
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(RuntimeError) as raise_info:
        model_group.add_model(["model0_path", "model1_path"])
    assert "ModelGroup's add model failed." in str(raise_info.value)
601
602
def test_model_group_add_model_invalid_model_path_with_model_obj_error():
    """A list mixing a path (first) with a Model object makes ``add_model`` raise TypeError."""
    model_group = mslite.ModelGroup(flags=mslite.ModelGroupFlag.SHARE_WEIGHT)
    model1 = mslite.Model()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model_group.add_model(["model_path", model1])
    assert "models element must be all str or Model" in str(raise_info.value)
609
610
def test_model_group_add_model_invalid_model_obj_with_model_path_error():
    """A list mixing a Model object (first) with a path makes ``add_model`` raise TypeError."""
    model_group = mslite.ModelGroup(flags=mslite.ModelGroupFlag.SHARE_WEIGHT)
    model1 = mslite.Model()
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model_group.add_model([model1, "model_path"])
    assert "models element must be all str or Model" in str(raise_info.value)
617
618
def test_model_group_add_model_invalid_model_obj_type_error():
    """Passing a bare str instead of a list/tuple to ``add_model`` raises TypeError."""
    model_group = mslite.ModelGroup(flags=mslite.ModelGroupFlag.SHARE_WEIGHT)
    # Setup stays outside pytest.raises so only the statement under test can satisfy it.
    with pytest.raises(TypeError) as raise_info:
        model_group.add_model("model_path")
    assert "models must be list/tuple, but got" in str(raise_info.value)
624