1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ================================ 15"""Imports a protobuf model as a graph in Tensorboard.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import argparse 22import sys 23 24from tensorflow.core.framework import graph_pb2 25from tensorflow.python.client import session 26from tensorflow.python.framework import importer 27from tensorflow.python.framework import ops 28from tensorflow.python.platform import app 29from tensorflow.python.platform import gfile 30from tensorflow.python.summary import summary 31 32# Try importing TensorRT ops if available 33# TODO(aaroey): ideally we should import everything from contrib, but currently 34# tensorrt module would cause build errors when being imported in 35# tensorflow/contrib/__init__.py. Fix it. 36# pylint: disable=unused-import,g-import-not-at-top,wildcard-import 37try: 38 from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import * 39except ImportError: 40 pass 41# pylint: enable=unused-import,g-import-not-at-top,wildcard-import 42 43def import_to_tensorboard(model_dir, log_dir): 44 """View an imported protobuf model (`.pb` file) as a graph in Tensorboard. 45 46 Args: 47 model_dir: The location of the protobuf (`pb`) model to visualize 48 log_dir: The location for the Tensorboard log to begin visualization from. 49 50 Usage: 51 Call this function with your model location and desired log directory. 52 Launch Tensorboard by pointing it to the log directory. 53 View your imported `.pb` model as a graph. 54 """ 55 with session.Session(graph=ops.Graph()) as sess: 56 with gfile.GFile(model_dir, "rb") as f: 57 graph_def = graph_pb2.GraphDef() 58 graph_def.ParseFromString(f.read()) 59 importer.import_graph_def(graph_def) 60 61 pb_visual_writer = summary.FileWriter(log_dir) 62 pb_visual_writer.add_graph(sess.graph) 63 print("Model Imported. Visualize by running: " 64 "tensorboard --logdir={}".format(log_dir)) 65 66 67def main(unused_args): 68 import_to_tensorboard(FLAGS.model_dir, FLAGS.log_dir) 69 70if __name__ == "__main__": 71 parser = argparse.ArgumentParser() 72 parser.register("type", "bool", lambda v: v.lower() == "true") 73 parser.add_argument( 74 "--model_dir", 75 type=str, 76 default="", 77 required=True, 78 help="The location of the protobuf (\'pb\') model to visualize.") 79 parser.add_argument( 80 "--log_dir", 81 type=str, 82 default="", 83 required=True, 84 help="The location for the Tensorboard log to begin visualization from.") 85 FLAGS, unparsed = parser.parse_known_args() 86 app.run(main=main, argv=[sys.argv[0]] + unparsed) 87