# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This tool analyzes a TensorFlow Lite graph."""

import http.server
import os

# The native analyzer wrapper lives in a different package depending on whether
# this module was shipped inside `tensorflow` or inside the standalone
# `tflite_runtime` wheel; the location of this file tells the two apart.
# pylint: disable=g-import-not-at-top
if not os.path.splitext(__file__)[0].endswith(
    os.path.join("tflite_runtime", "analyzer")):
  # This file is part of tensorflow package.
  from tensorflow.lite.python.analyzer_wrapper import _pywrap_analyzer_wrapper as _analyzer_wrapper
else:
  # This file is part of tflite_runtime package.
  from tflite_runtime import _pywrap_analyzer_wrapper as _analyzer_wrapper


def _handle_webserver(host_name, server_port, html_body):
  """Serves `html_body` over HTTP until interrupted with Ctrl-C.

  Every GET request, regardless of path, is answered with status 200 and the
  same HTML document. The call blocks in `serve_forever()` and returns only
  after a `KeyboardInterrupt`; any other exception propagates to the caller.
  In both cases the listening socket is closed before the function exits.

  Args:
    host_name: Host name or address to bind the server to.
    server_port: TCP port to listen on.
    html_body: The HTML document to serve, as a `str`.
  """
  # Encode once up front instead of on every request; this also surfaces a
  # non-string `html_body` before the server starts accepting connections.
  payload = bytes(html_body, "utf-8")

  class MyServer(http.server.BaseHTTPRequestHandler):

    def do_GET(self):  # pylint: disable=invalid-name
      self.send_response(200)
      # The body is always UTF-8 encoded, so declare the charset explicitly
      # rather than leaving the client to guess.
      self.send_header("Content-type", "text/html; charset=utf-8")
      self.end_headers()
      self.wfile.write(payload)

  # The context manager calls `server_close()` on exit, so the socket is
  # released even when `serve_forever()` fails with something other than
  # KeyboardInterrupt.
  with http.server.HTTPServer((host_name, server_port),
                              MyServer) as web_server:
    print(f"Server started http://{host_name}:{server_port}")
    try:
      web_server.serve_forever()
    except KeyboardInterrupt:
      # Ctrl-C is the expected way to stop the server; shut down quietly.
      pass


class ModelAnalyzer:
  """Provides a collection of TFLite model analyzer tools."""

  @staticmethod
  def analyze(model_path=None,
              model_content=None,
              experimental_use_mlir=False,
              gpu_compatibility=False):
    """Analyzes the given tflite_model and prints the report to the console.

    Exactly one model source is used: if `model_path` is given it takes
    precedence and `model_content` is ignored.

    Args:
      model_path: TFLite flatbuffer model path.
      model_content: TFLite flatbuffer model object.
      experimental_use_mlir: Use MLIR format for model dump. When set,
        `gpu_compatibility` is not consulted.
      gpu_compatibility: Whether to check GPU delegate compatibility.

    Returns:
      None. The analysis report is written to standard output.

    Raises:
      ValueError: If both `model_path` and `model_content` are missing or
        empty.
    """
    if not model_path and not model_content:
      raise ValueError("neither `model_path` nor `model_content` is provided")
    if model_path:
      print(f"=== {model_path} ===\n")
      tflite_model = model_path
      input_is_filepath = True
    else:
      print("=== TFLite ModelAnalyzer ===\n")
      tflite_model = model_content
      input_is_filepath = False

    if experimental_use_mlir:
      print(_analyzer_wrapper.FlatBufferToMlir(tflite_model, input_is_filepath))
    else:
      print(
          _analyzer_wrapper.ModelAnalyzer(tflite_model, input_is_filepath,
                                          gpu_compatibility))