1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A simple script for inspect checkpoint files.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import argparse 21import re 22import sys 23 24import numpy as np 25 26from tensorflow.python.platform import app 27from tensorflow.python.platform import flags 28from tensorflow.python.training import py_checkpoint_reader 29 30FLAGS = None 31 32 33def _count_total_params(reader, count_exclude_pattern=""): 34 """Count total number of variables.""" 35 var_to_shape_map = reader.get_variable_to_shape_map() 36 37 # Filter out tensors that we don't want to count 38 if count_exclude_pattern: 39 regex_pattern = re.compile(count_exclude_pattern) 40 new_var_to_shape_map = {} 41 exclude_num_tensors = 0 42 exclude_num_params = 0 43 for v in var_to_shape_map: 44 if regex_pattern.search(v): 45 exclude_num_tensors += 1 46 exclude_num_params += np.prod(var_to_shape_map[v]) 47 else: 48 new_var_to_shape_map[v] = var_to_shape_map[v] 49 var_to_shape_map = new_var_to_shape_map 50 print("# Excluding %d tensors (%d params) that match %s when counting." % ( 51 exclude_num_tensors, exclude_num_params, count_exclude_pattern)) 52 53 var_sizes = [np.prod(var_to_shape_map[v]) for v in var_to_shape_map] 54 return np.sum(var_sizes, dtype=int) 55 56 57def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors, 58 all_tensor_names=False, 59 count_exclude_pattern=""): 60 """Prints tensors in a checkpoint file. 61 62 If no `tensor_name` is provided, prints the tensor names and shapes 63 in the checkpoint file. 64 65 If `tensor_name` is provided, prints the content of the tensor. 66 67 Args: 68 file_name: Name of the checkpoint file. 69 tensor_name: Name of the tensor in the checkpoint file to print. 70 all_tensors: Boolean indicating whether to print all tensors. 71 all_tensor_names: Boolean indicating whether to print all tensor names. 72 count_exclude_pattern: Regex string, pattern to exclude tensors when count. 73 """ 74 try: 75 reader = py_checkpoint_reader.NewCheckpointReader(file_name) 76 if all_tensors or all_tensor_names: 77 var_to_shape_map = reader.get_variable_to_shape_map() 78 var_to_dtype_map = reader.get_variable_to_dtype_map() 79 for key, value in sorted(var_to_shape_map.items()): 80 print("tensor: %s (%s) %s" % (key, var_to_dtype_map[key].name, value)) 81 if all_tensors: 82 print(reader.get_tensor(key)) 83 elif not tensor_name: 84 print(reader.debug_string().decode("utf-8", errors="ignore")) 85 else: 86 if not reader.has_tensor(tensor_name): 87 print("Tensor %s not found in checkpoint" % tensor_name) 88 return 89 90 var_to_shape_map = reader.get_variable_to_shape_map() 91 var_to_dtype_map = reader.get_variable_to_dtype_map() 92 print("tensor: %s (%s) %s" % 93 (tensor_name, var_to_dtype_map[tensor_name].name, 94 var_to_shape_map[tensor_name])) 95 print(reader.get_tensor(tensor_name)) 96 97 # Count total number of parameters 98 print("# Total number of params: %d" % _count_total_params( 99 reader, count_exclude_pattern=count_exclude_pattern)) 100 except Exception as e: # pylint: disable=broad-except 101 print(str(e)) 102 if "corrupted compressed block contents" in str(e): 103 print("It's likely that your checkpoint file has been compressed " 104 "with SNAPPY.") 105 if ("Data loss" in str(e) and 106 any(e in file_name for e in [".index", ".meta", ".data"])): 107 proposed_file = ".".join(file_name.split(".")[0:-1]) 108 v2_file_error_template = """ 109It's likely that this is a V2 checkpoint and you need to provide the filename 110*prefix*. Try removing the '.' and extension. Try: 111inspect checkpoint --file_name = {}""" 112 print(v2_file_error_template.format(proposed_file)) 113 114 115def parse_numpy_printoption(kv_str): 116 """Sets a single numpy printoption from a string of the form 'x=y'. 117 118 See documentation on numpy.set_printoptions() for details about what values 119 x and y can take. x can be any option listed there other than 'formatter'. 120 121 Args: 122 kv_str: A string of the form 'x=y', such as 'threshold=100000' 123 124 Raises: 125 argparse.ArgumentTypeError: If the string couldn't be used to set any 126 nump printoption. 127 """ 128 k_v_str = kv_str.split("=", 1) 129 if len(k_v_str) != 2 or not k_v_str[0]: 130 raise argparse.ArgumentTypeError("'%s' is not in the form k=v." % kv_str) 131 k, v_str = k_v_str 132 printoptions = np.get_printoptions() 133 if k not in printoptions: 134 raise argparse.ArgumentTypeError("'%s' is not a valid printoption." % k) 135 v_type = type(printoptions[k]) 136 if v_type is type(None): 137 raise argparse.ArgumentTypeError( 138 "Setting '%s' from the command line is not supported." % k) 139 try: 140 v = ( 141 v_type(v_str) 142 if v_type is not bool else flags.BooleanParser().parse(v_str)) 143 except ValueError as e: 144 raise argparse.ArgumentTypeError(e.message) 145 np.set_printoptions(**{k: v}) 146 147 148def main(unused_argv): 149 if not FLAGS.file_name: 150 print("Usage: inspect_checkpoint --file_name=checkpoint_file_name " 151 "[--tensor_name=tensor_to_print] " 152 "[--all_tensors] " 153 "[--all_tensor_names] " 154 "[--printoptions]") 155 sys.exit(1) 156 else: 157 print_tensors_in_checkpoint_file( 158 FLAGS.file_name, FLAGS.tensor_name, 159 FLAGS.all_tensors, FLAGS.all_tensor_names, 160 count_exclude_pattern=FLAGS.count_exclude_pattern) 161 162 163if __name__ == "__main__": 164 parser = argparse.ArgumentParser() 165 parser.register("type", "bool", lambda v: v.lower() == "true") 166 parser.add_argument( 167 "--file_name", 168 type=str, 169 default="", 170 help="Checkpoint filename. " 171 "Note, if using Checkpoint V2 format, file_name is the " 172 "shared prefix between all files in the checkpoint.") 173 parser.add_argument( 174 "--tensor_name", 175 type=str, 176 default="", 177 help="Name of the tensor to inspect") 178 parser.add_argument( 179 "--count_exclude_pattern", 180 type=str, 181 default="", 182 help="Pattern to exclude tensors, e.g., from optimizers, when counting.") 183 parser.add_argument( 184 "--all_tensors", 185 nargs="?", 186 const=True, 187 type="bool", 188 default=False, 189 help="If True, print the names and values of all the tensors.") 190 parser.add_argument( 191 "--all_tensor_names", 192 nargs="?", 193 const=True, 194 type="bool", 195 default=False, 196 help="If True, print the names of all the tensors.") 197 parser.add_argument( 198 "--printoptions", 199 nargs="*", 200 type=parse_numpy_printoption, 201 help="Argument for numpy.set_printoptions(), in the form 'k=v'.") 202 FLAGS, unparsed = parser.parse_known_args() 203 app.run(main=main, argv=[sys.argv[0]] + unparsed) 204