1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""label_image for tflite""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import argparse 22import numpy as np 23 24from PIL import Image 25 26from tensorflow.lite.python import interpreter as interpreter_wrapper 27 28def load_labels(filename): 29 my_labels = [] 30 input_file = open(filename, 'r') 31 for l in input_file: 32 my_labels.append(l.strip()) 33 return my_labels 34 35if __name__ == "__main__": 36 floating_model = False 37 38 parser = argparse.ArgumentParser() 39 parser.add_argument("-i", "--image", default="/tmp/grace_hopper.bmp", \ 40 help="image to be classified") 41 parser.add_argument("-m", "--model_file", \ 42 default="/tmp/mobilenet_v1_1.0_224_quant.tflite", \ 43 help=".tflite model to be executed") 44 parser.add_argument("-l", "--label_file", default="/tmp/labels.txt", \ 45 help="name of file containing labels") 46 parser.add_argument("--input_mean", default=127.5, help="input_mean") 47 parser.add_argument("--input_std", default=127.5, \ 48 help="input standard deviation") 49 args = parser.parse_args() 50 51 interpreter = interpreter_wrapper.Interpreter(model_path=args.model_file) 52 interpreter.allocate_tensors() 53 54 input_details = interpreter.get_input_details() 55 output_details = interpreter.get_output_details() 56 57 # check the type of the input tensor 58 if input_details[0]['dtype'] == np.float32: 59 floating_model = True 60 61 # NxHxWxC, H:1, W:2 62 height = input_details[0]['shape'][1] 63 width = input_details[0]['shape'][2] 64 img = Image.open(args.image) 65 img = img.resize((width, height)) 66 67 # add N dim 68 input_data = np.expand_dims(img, axis=0) 69 70 if floating_model: 71 input_data = (np.float32(input_data) - args.input_mean) / args.input_std 72 73 interpreter.set_tensor(input_details[0]['index'], input_data) 74 75 interpreter.invoke() 76 77 output_data = interpreter.get_tensor(output_details[0]['index']) 78 results = np.squeeze(output_data) 79 80 top_k = results.argsort()[-5:][::-1] 81 labels = load_labels(args.label_file) 82 for i in top_k: 83 if floating_model: 84 print('{0:08.6f}'.format(float(results[i]))+":", labels[i]) 85 else: 86 print('{0:08.6f}'.format(float(results[i]/255.0))+":", labels[i]) 87