# Copyright (c) 2021 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import logging
import os


def is_tflite_model(model_path):
    """Check if a model is of TFLite type

    A file is treated as a TFLite model when its 4-byte identifier at byte
    offset 4 (bytes 4..7 of the file) equals ``TFL3``.

    Parameters:
    ----------
    model_path: str
        Path to model

    Returns
    ----------
    bool:
        True if the file at ``model_path`` carries the ``TFL3`` identifier.
        False otherwise, including when the path cannot be opened or read
        (missing file, directory, permission error, invalid path value).
    """
    try:
        with open(model_path, "rb") as f:
            header = f.read(8)
    except (OSError, TypeError, ValueError):
        # Best-effort probe: an unreadable or invalid path simply means
        # "not a TFLite model". The exception list is explicit so that
        # KeyboardInterrupt/SystemExit still propagate.
        return False

    # Compare raw bytes rather than decoding: arbitrary binary headers are
    # not valid UTF-8, and no decoding is needed to reject them. A file
    # shorter than 8 bytes yields a shorter slice and so compares unequal.
    return header[4:8] == b"TFL3"


def identify_model_type(model_path):
    """Identify the type of a given deep learning model

    Parameters:
    ----------
    model_path: str
        Path to model

    Returns
    ----------
    model_type: str or None
        ``"tflite"`` if the file is a TFLite model; ``None`` (the object, not
        a string) if the path does not exist or the type is not supported.
        A warning is logged on the root logger in both ``None`` cases.
    """
    if not os.path.exists(model_path):
        logging.warning("Provided model %s does not exist!", model_path)
        return None

    if is_tflite_model(model_path):
        return "tflite"

    # Log exactly one warning for an unsupported model.
    logging.warning("Provided model %s is not of supported type!", model_path)
    return None