# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

import os
import pytest
import cv2
import numpy as np

from context import network_executor
from context import network_executor_tflite
from context import cv_utils


@pytest.mark.parametrize("executor_name", ["armnn", "tflite"])
def test_execute_network(test_data_folder, executor_name):
    """Run ssd_mobilenet_v1.tflite through the selected executor and check its output.

    Args:
        test_data_folder: Directory holding the model, the sample image and
            (for the "tflite" case) the Arm NN delegate library.
        executor_name: Which executor implementation to build, either
            "armnn" or "tflite".

    Raises:
        ValueError: If ``executor_name`` is not one of the supported names.
    """
    model_path = os.path.join(test_data_folder, "ssd_mobilenet_v1.tflite")
    backends = ["CpuAcc", "CpuRef"]
    if executor_name == "armnn":
        executor = network_executor.ArmnnNetworkExecutor(model_path, backends)
    elif executor_name == "tflite":
        delegate_path = os.path.join(test_data_folder, "libarmnnDelegate.so")
        executor = network_executor_tflite.TFLiteNetworkExecutor(model_path, backends, delegate_path)
    else:
        # A bare string cannot be raised in Python 3; use a real exception type
        # so the message actually reaches the test report.
        raise ValueError(f"unsupported executor_name: {executor_name}")

    img_path = os.path.join(test_data_folder, "messi5.jpg")
    img = cv2.imread(img_path)
    # cv2.imread signals an unreadable file by returning None rather than raising.
    assert img is not None, f"could not read test image: {img_path}"
    resized_img = cv_utils.preprocess(img, executor.get_data_type(), executor.get_shape(), True)

    output_result = executor.run([resized_img])

    # Ensure it detects a person
    classes = output_result[1]
    assert classes[0][0] == 0

    # Unit tests for network executor class functions - specifically for ssd_mobilenet_v1.tflite network
    assert executor.get_data_type() == np.uint8
    assert executor.get_shape() == (1, 300, 300, 3)