• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""This tool analyzes a TensorFlow Lite graph."""
16
17import http.server
18import os
19
20# pylint: disable=g-import-not-at-top
21if not os.path.splitext(__file__)[0].endswith(
22    os.path.join("tflite_runtime", "analyzer")):
23  # This file is part of tensorflow package.
24  from tensorflow.lite.python.analyzer_wrapper import _pywrap_analyzer_wrapper as _analyzer_wrapper
25else:
26  # This file is part of tflite_runtime package.
27  from tflite_runtime import _pywrap_analyzer_wrapper as _analyzer_wrapper
28
29
def _handle_webserver(host_name, server_port, html_body):
  """Serves `html_body` over HTTP until interrupted.

  Every GET request, regardless of path, is answered with status 200 and
  `html_body` encoded as UTF-8. The call blocks in `serve_forever()` until
  Ctrl-C (KeyboardInterrupt); the listening socket is always closed on exit.

  Args:
    host_name: Host name or address to bind the server to.
    server_port: TCP port to listen on. 0 lets the OS pick a free port; the
      port actually bound is the one printed.
    html_body: HTML document (str) returned for every GET request.
  """
  # Encode once up front rather than on every request.
  payload = html_body.encode("utf-8")

  class _HtmlHandler(http.server.BaseHTTPRequestHandler):
    """Answers every GET request with the pre-encoded HTML payload."""

    def do_GET(self):  # pylint: disable=invalid-name
      self.send_response(200)
      # Declare the charset the body is actually encoded with, and its length.
      self.send_header("Content-type", "text/html; charset=utf-8")
      self.send_header("Content-Length", str(len(payload)))
      self.end_headers()
      self.wfile.write(payload)

  # The `with` block guarantees server_close() runs even if serve_forever()
  # raises something other than KeyboardInterrupt (that path used to leak the
  # listening socket).
  with http.server.HTTPServer((host_name, server_port),
                              _HtmlHandler) as web_server:
    # server_port reflects the bound port, which matters when 0 was requested.
    print(f"Server started http://{host_name}:{web_server.server_port}")
    try:
      web_server.serve_forever()
    except KeyboardInterrupt:
      pass  # Ctrl-C is the intended way to stop the server.
48
49
class ModelAnalyzer:
  """Collection of TFLite model analyzer tools."""

  @staticmethod
  def analyze(model_path=None,
              model_content=None,
              experimental_use_mlir=False,
              gpu_compatibility=False):
    """Prints an analysis report for a TFLite model to stdout.

    Exactly one model source is used: `model_path` when it is truthy,
    otherwise `model_content`.

    Args:
      model_path: Path to a TFLite flatbuffer file. Takes precedence over
        `model_content` when both are given.
      model_content: In-memory TFLite flatbuffer model (presumably the raw
        serialized bytes -- confirm against the native wrapper).
      experimental_use_mlir: If True, dump the model in MLIR form instead of
        the default analyzer report. `gpu_compatibility` is not used on this
        path.
      gpu_compatibility: Whether the default report should also check GPU
        delegate compatibility.

    Returns:
      None. The report is written to stdout via print().

    Raises:
      ValueError: If neither `model_path` nor `model_content` is truthy.
    """
    if not any((model_path, model_content)):
      raise ValueError("neither `model_path` nor `model_content` is provided")

    # The native wrapper needs to know whether `model` is a path or the
    # model contents themselves.
    from_file = bool(model_path)
    if from_file:
      model, banner = model_path, f"=== {model_path} ===\n"
    else:
      model, banner = model_content, "=== TFLite ModelAnalyzer ===\n"
    print(banner)

    if experimental_use_mlir:
      report = _analyzer_wrapper.FlatBufferToMlir(model, from_file)
    else:
      report = _analyzer_wrapper.ModelAnalyzer(model, from_file,
                                               gpu_compatibility)
    print(report)
86