# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test lite python API.
"""
import mindspore_lite as mslite
import numpy as np


def common_predict(context, model_path, in_data_path):
    """Build a model from a file, run one prediction and print every output.

    Args:
        context: The ``mslite.Context`` passed to ``Model.build_from_file``.
        model_path: Path of the model file; it is loaded with
            ``mslite.ModelType.MINDIR_LITE``.
        in_data_path: Path of a binary file that is read as ``float32``
            values and written into the first input tensor.
    """
    lite_model = mslite.Model()
    model_type = mslite.ModelType.MINDIR_LITE
    lite_model.build_from_file(model_path, model_type, context)

    input_tensors = lite_model.get_inputs()
    raw_input = np.fromfile(in_data_path, dtype=np.float32)
    # Only the first input tensor is populated.
    input_tensors[0].set_data_from_numpy(raw_input)

    for out_tensor in lite_model.predict(input_tensors):
        print("data: ", out_tensor.get_data_to_numpy())


# ============================ cpu inference ============================
def test_cpu_inference_01():
    """Run mobilenetv2.ms on a CPU-target context (1 thread, affinity mode 2)."""
    model_file = "mobilenetv2.ms"
    input_file = "mobilenetv2.ms.bin"

    ctx = mslite.Context()
    ctx.target = ["cpu"]
    ctx.cpu.thread_num = 1
    ctx.cpu.thread_affinity_mode = 2

    common_predict(ctx, model_file, input_file)


# ============================ gpu inference ============================
def test_gpu_inference_01():
    """Run mobilenetv2.ms on a GPU-target context (device 0).

    CPU options are set on the same context as well; the printed label
    refers to them as "cpu_backup".
    """
    model_file = "mobilenetv2.ms"
    input_file = "mobilenetv2.ms.bin"

    ctx = mslite.Context()
    ctx.target = ["gpu"]

    ctx.gpu.device_id = 0
    print("gpu: ", ctx.gpu)

    ctx.cpu.thread_num = 1
    ctx.cpu.thread_affinity_mode = 2
    print("cpu_backup: ", ctx.cpu)

    common_predict(ctx, model_file, input_file)


# ============================ ascend inference ============================
def test_ascend_inference_01():
    """Run mnist.tflite.ms on an Ascend-target context (device 0).

    CPU options are set on the same context as well; the printed label
    refers to them as "cpu_backup".
    """
    model_file = "mnist.tflite.ms"
    input_file = "mnist.tflite.ms.bin"

    ctx = mslite.Context()
    ctx.target = ["ascend"]

    ctx.ascend.device_id = 0
    print("ascend: ", ctx.ascend)

    ctx.cpu.thread_num = 1
    ctx.cpu.thread_affinity_mode = 2
    print("cpu_backup: ", ctx.cpu)

    common_predict(ctx, model_file, input_file)


# ============================ server inference ============================
def test_server_inference_01():
    """Run mobilenetv2.ms through ``ModelParallelRunner`` and print output 0.

    The context targets the CPU with 4 threads and sets
    ``parallel.workers_num`` to 1. Only the first output tensor is printed.
    """
    model_file = "mobilenetv2.ms"
    input_file = "mobilenetv2.ms.bin"

    ctx = mslite.Context()
    ctx.target = ["cpu"]
    ctx.cpu.thread_num = 4
    ctx.parallel.workers_num = 1

    runner = mslite.ModelParallelRunner()
    runner.build_from_file(model_path=model_file, context=ctx)

    input_tensors = runner.get_inputs()
    raw_input = np.fromfile(input_file, dtype=np.float32)
    input_tensors[0].set_data_from_numpy(raw_input)

    results = runner.predict(input_tensors)
    first_output = results[0]
    print("data: ", first_output.get_data_to_numpy())