• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ================================
15"""Imports a protobuf model as a graph in Tensorboard."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import argparse
22import sys
23
24from tensorflow.core.framework import graph_pb2
25from tensorflow.python.client import session
26from tensorflow.python.framework import importer
27from tensorflow.python.framework import ops
28from tensorflow.python.platform import app
29from tensorflow.python.platform import gfile
30from tensorflow.python.summary import summary
31
32# Try importing TensorRT ops if available
33# TODO(aaroey): ideally we should import everything from contrib, but currently
34# tensorrt module would cause build errors when being imported in
35# tensorflow/contrib/__init__.py. Fix it.
36# pylint: disable=unused-import,g-import-not-at-top,wildcard-import
37try:
38  from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import *
39except ImportError:
40  pass
41# pylint: enable=unused-import,g-import-not-at-top,wildcard-import
42
def import_to_tensorboard(model_dir, log_dir):
  """View an imported protobuf model (`.pb` file) as a graph in Tensorboard.

  Reads a serialized `GraphDef` from `model_dir`, imports it into a fresh
  graph, and writes that graph to an event file under `log_dir`.

  Args:
    model_dir: Path to the serialized `GraphDef` protobuf (`.pb`) file to
      visualize. Despite the name, this is opened as a file, not a directory.
    log_dir: The location for the Tensorboard log to begin visualization from.

  Usage:
    Call this function with your model location and desired log directory.
    Launch Tensorboard by pointing it to the log directory.
    View your imported `.pb` model as a graph.
  """
  with session.Session(graph=ops.Graph()) as sess:
    # Only parsing needs the file handle; release it before importing.
    with gfile.GFile(model_dir, "rb") as f:
      graph_def = graph_pb2.GraphDef()
      graph_def.ParseFromString(f.read())
    importer.import_graph_def(graph_def)

    pb_visual_writer = summary.FileWriter(log_dir)
    try:
      pb_visual_writer.add_graph(sess.graph)
    finally:
      # FileWriter updates the event file asynchronously; close() flushes
      # pending events to disk and releases the file. Without it the graph
      # event is not guaranteed to reach the log before the process exits.
      pb_visual_writer.close()
    print("Model Imported. Visualize by running: "
          "tensorboard --logdir={}".format(log_dir))
65
66
def main(unused_args):
  """Runs `import_to_tensorboard` using the parsed command-line flags.

  Serves as the `main` callback handed to `app.run`. It reads the
  module-level `FLAGS` namespace, which is only bound when this file is
  executed as a script.

  Args:
    unused_args: Arguments passed in by `app.run`; ignored.
  """
  del unused_args  # Unused; all configuration comes from FLAGS.
  pb_path = FLAGS.model_dir
  log_path = FLAGS.log_dir
  import_to_tensorboard(pb_path, log_path)
69
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  # Both flags are mandatory, so no defaults are declared: argparse exits
  # with a usage error if either one is missing.
  parser.add_argument(
      "--model_dir",
      type=str,
      required=True,
      help="The location of the protobuf ('pb') model to visualize.")
  parser.add_argument(
      "--log_dir",
      type=str,
      required=True,
      help="The location for the Tensorboard log to begin visualization from.")
  # FLAGS is bound at module scope so main() can read it. Unrecognized
  # arguments are forwarded to app.run instead of being rejected here.
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
87