• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
2# SPDX-License-Identifier: MIT
3
4import os
5import pytest
6import cv2
7import numpy as np
8
9from context import network_executor
10from context import network_executor_tflite
11from context import cv_utils
12
@pytest.mark.parametrize("executor_name", ["armnn", "tflite"])
def test_execute_network(test_data_folder, executor_name):
    """Run ssd_mobilenet_v1.tflite through the selected executor and check its output.

    The same model is executed either through the Arm NN executor or through the
    TFLite executor (loaded with ``libarmnnDelegate.so``), on the ``messi5.jpg``
    test image. The test then checks the first detected class and the model's
    reported input data type and shape.

    Args:
        test_data_folder: Directory (supplied by a fixture defined elsewhere) that
            holds ``ssd_mobilenet_v1.tflite``, ``messi5.jpg`` and, for the TFLite
            case, ``libarmnnDelegate.so``.
        executor_name: Which executor to build: ``"armnn"`` or ``"tflite"``.

    Raises:
        ValueError: If ``executor_name`` is not one of the supported values.
    """
    model_path = os.path.join(test_data_folder, "ssd_mobilenet_v1.tflite")
    # Backend list handed to both executors; presumably tried in order of
    # preference -- TODO confirm against the executor implementations.
    backends = ["CpuAcc", "CpuRef"]
    if executor_name == "armnn":
        executor = network_executor.ArmnnNetworkExecutor(model_path, backends)
    elif executor_name == "tflite":
        delegate_path = os.path.join(test_data_folder, "libarmnnDelegate.so")
        executor = network_executor_tflite.TFLiteNetworkExecutor(
            model_path, backends, delegate_path
        )
    else:
        # Raising a bare string is itself a TypeError in Python 3 and would hide
        # this message, so raise a real exception type instead.
        raise ValueError(f"unsupported executor_name: {executor_name}")

    img_path = os.path.join(test_data_folder, "messi5.jpg")
    img = cv2.imread(img_path)
    # cv2.imread signals a missing/unreadable file by returning None rather than
    # raising; fail here with a clear message instead of inside preprocess().
    assert img is not None, f"failed to read test image: {img_path}"
    resized_img = cv_utils.preprocess(
        img, executor.get_data_type(), executor.get_shape(), True
    )

    output_result = executor.run([resized_img])

    # Ensure it detects a person: the second output holds the class ids, and the
    # top detection is expected to be class 0.
    classes = output_result[1]
    assert classes[0][0] == 0

    # Unit tests for network executor class functions - specifically for
    # ssd_mobilenet_v1.tflite network (quantized uint8 input, 300x300 RGB, batch 1).
    assert executor.get_data_type() == np.uint8
    assert executor.get_shape() == (1, 300, 300, 3)
37