1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A simple script for inspect checkpoint files (deprecated).""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import argparse 22import sys 23 24from tensorflow.contrib.framework.python.framework import checkpoint_utils 25from tensorflow.python.platform import app 26 27FLAGS = None 28 29 30def print_tensors_in_checkpoint_file(file_name, tensor_name): 31 """Prints tensors in a checkpoint file. 32 33 If no `tensor_name` is provided, prints the tensor names and shapes 34 in the checkpoint file. 35 36 If `tensor_name` is provided, prints the content of the tensor. 37 38 Args: 39 file_name: Name of the checkpoint file. 40 tensor_name: Name of the tensor in the checkpoint file to print. 41 """ 42 try: 43 if not tensor_name: 44 variables = checkpoint_utils.list_variables(file_name) 45 for name, shape in variables: 46 print("%s\t%s" % (name, str(shape))) 47 else: 48 print("tensor_name: ", tensor_name) 49 print(checkpoint_utils.load_variable(file_name, tensor_name)) 50 except Exception as e: # pylint: disable=broad-except 51 print(str(e)) 52 if "corrupted compressed block contents" in str(e): 53 print("It's likely that your checkpoint file has been compressed " 54 "with SNAPPY.") 55 56 57def main(unused_argv): 58 if not FLAGS.file_name: 59 print("Usage: inspect_checkpoint --file_name=<checkpoint_file_name " 60 "or directory> [--tensor_name=tensor_to_print]") 61 sys.exit(1) 62 else: 63 print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name) 64 65 66if __name__ == "__main__": 67 parser = argparse.ArgumentParser() 68 parser.register("type", "bool", lambda v: v.lower() == "true") 69 parser.add_argument( 70 "--file_name", 71 type=str, 72 default="", 73 help="Checkpoint filename" 74 ) 75 parser.add_argument( 76 "--tensor_name", 77 type=str, 78 default="", 79 help="Name of the tensor to inspect" 80 ) 81 FLAGS, unparsed = parser.parse_known_args() 82 app.run(main=main, argv=[sys.argv[0]] + unparsed) 83