#!/usr/bin/env python3
# Copyright © 2020 NXP and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

"""Classify images with the quantized MobileNet v1 (224x224) TFLite model.

The model and labels files are located via ``example_utils.get_model_and_labels``,
the images via ``example_utils.get_images``, and inference is performed by
``example_utils.run_inference``.
"""

import os

import example_utils as eu


def _require_path(path, description):
    """Raise FileNotFoundError if *path* does not exist.

    An explicit check is used instead of ``assert`` so that validation still
    happens when Python runs with ``-O`` (which strips assert statements).
    """
    if not os.path.exists(path):
        raise FileNotFoundError('{} not found: {}'.format(description, path))


def main():
    """Run the quantized MobileNet v1 TFLite image-classification example."""
    args = eu.parse_command_line()

    # Names of the files inside the model archive.
    labels_filename = 'labels_mobilenet_quant_v1_224.txt'
    model_filename = 'mobilenet_v1_1.0_224_quant.tflite'
    archive_filename = 'mobilenet_v1_1.0_224_quant_and_labels.zip'

    archive_url = ('https://storage.googleapis.com/download.tensorflow.org/models/tflite/'
                   'mobilenet_v1_1.0_224_quant_and_labels.zip')

    # get_model_and_labels returns the paths to use for the model and labels
    # files (presumably fetching/extracting the archive when needed — see
    # example_utils).
    model_filename, labels_filename = eu.get_model_and_labels(args.model_dir, model_filename, labels_filename,
                                                              archive_filename, archive_url)

    image_filenames = eu.get_images(args.data_dir)

    # The model, the labels and at least one image must exist to proceed.
    _require_path(labels_filename, 'Labels file')
    _require_path(model_filename, 'Model file')
    if not image_filenames:
        raise FileNotFoundError('No images found for data_dir: {}'.format(args.data_dir))
    for image_filename in image_filenames:
        _require_path(image_filename, 'Image file')

    # Create a network from the model file.
    net_id, graph_id, parser, runtime = eu.create_tflite_network(model_filename)

    # Load input information from the model.
    # TFLite models contain all the needed tensor information, unlike other formats.
    input_names = parser.GetSubgraphInputTensorNames(graph_id)
    if len(input_names) != 1:  # MobileNet has exactly one input tensor
        raise RuntimeError('Expected 1 input tensor, got {}'.format(len(input_names)))

    input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
    # Element 1 of the binding info is the tensor info (it exposes GetShape()).
    # TFLite image inputs are laid out NHWC: dimension 1 is the height and
    # dimension 2 is the width.
    input_shape = input_binding_info[1].GetShape()
    input_height = input_shape[1]
    input_width = input_shape[2]

    # Load output information from the model.
    output_names = parser.GetSubgraphOutputTensorNames(graph_id)
    if len(output_names) != 1:  # ...and exactly one output tensor
        raise RuntimeError('Expected 1 output tensor, got {}'.format(len(output_names)))
    output_binding_info = parser.GetNetworkOutputBindingInfo(graph_id, output_names[0])

    # Load labels file.
    labels = eu.load_labels(labels_filename)

    # Load images and resize them to the size the model expects.
    images = eu.load_images(image_filenames, input_width, input_height)

    eu.run_inference(runtime, net_id, images, labels, input_binding_info, output_binding_info)


if __name__ == "__main__":
    main()