1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A simple script for inspect checkpoint files.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import argparse 21import re 22import sys 23 24from absl import app 25import numpy as np 26 27from tensorflow.python.framework import errors_impl 28from tensorflow.python.platform import flags 29from tensorflow.python.training import py_checkpoint_reader 30 31FLAGS = None 32 33 34def _count_total_params(reader, count_exclude_pattern=""): 35 """Count total number of variables.""" 36 var_to_shape_map = reader.get_variable_to_shape_map() 37 38 # Filter out tensors that we don't want to count 39 if count_exclude_pattern: 40 regex_pattern = re.compile(count_exclude_pattern) 41 new_var_to_shape_map = {} 42 exclude_num_tensors = 0 43 exclude_num_params = 0 44 for v in var_to_shape_map: 45 if regex_pattern.search(v): 46 exclude_num_tensors += 1 47 exclude_num_params += np.prod(var_to_shape_map[v]) 48 else: 49 new_var_to_shape_map[v] = var_to_shape_map[v] 50 var_to_shape_map = new_var_to_shape_map 51 print("# Excluding %d tensors (%d params) that match %s when counting." % ( 52 exclude_num_tensors, exclude_num_params, count_exclude_pattern)) 53 54 var_sizes = [np.prod(var_to_shape_map[v]) for v in var_to_shape_map] 55 return np.sum(var_sizes, dtype=int) 56 57 58def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors, 59 all_tensor_names=False, 60 count_exclude_pattern=""): 61 """Prints tensors in a checkpoint file. 62 63 If no `tensor_name` is provided, prints the tensor names and shapes 64 in the checkpoint file. 65 66 If `tensor_name` is provided, prints the content of the tensor. 67 68 Args: 69 file_name: Name of the checkpoint file. 70 tensor_name: Name of the tensor in the checkpoint file to print. 71 all_tensors: Boolean indicating whether to print all tensors. 72 all_tensor_names: Boolean indicating whether to print all tensor names. 73 count_exclude_pattern: Regex string, pattern to exclude tensors when count. 74 """ 75 try: 76 reader = py_checkpoint_reader.NewCheckpointReader(file_name) 77 if all_tensors or all_tensor_names: 78 var_to_shape_map = reader.get_variable_to_shape_map() 79 var_to_dtype_map = reader.get_variable_to_dtype_map() 80 for key, value in sorted(var_to_shape_map.items()): 81 print("tensor: %s (%s) %s" % (key, var_to_dtype_map[key].name, value)) 82 if all_tensors: 83 try: 84 print(reader.get_tensor(key)) 85 except errors_impl.InternalError: 86 print("<not convertible to a numpy dtype>") 87 elif not tensor_name: 88 print(reader.debug_string().decode("utf-8", errors="ignore")) 89 else: 90 if not reader.has_tensor(tensor_name): 91 print("Tensor %s not found in checkpoint" % tensor_name) 92 return 93 94 var_to_shape_map = reader.get_variable_to_shape_map() 95 var_to_dtype_map = reader.get_variable_to_dtype_map() 96 print("tensor: %s (%s) %s" % 97 (tensor_name, var_to_dtype_map[tensor_name].name, 98 var_to_shape_map[tensor_name])) 99 print(reader.get_tensor(tensor_name)) 100 101 # Count total number of parameters 102 print("# Total number of params: %d" % _count_total_params( 103 reader, count_exclude_pattern=count_exclude_pattern)) 104 except Exception as e: # pylint: disable=broad-except 105 print(str(e)) 106 if "corrupted compressed block contents" in str(e): 107 print("It's likely that your checkpoint file has been compressed " 108 "with SNAPPY.") 109 if ("Data loss" in str(e) and 110 any(e in file_name for e in [".index", ".meta", ".data"])): 111 proposed_file = ".".join(file_name.split(".")[0:-1]) 112 v2_file_error_template = """ 113It's likely that this is a V2 checkpoint and you need to provide the filename 114*prefix*. Try removing the '.' and extension. Try: 115inspect checkpoint --file_name = {}""" 116 print(v2_file_error_template.format(proposed_file)) 117 118 119def parse_numpy_printoption(kv_str): 120 """Sets a single numpy printoption from a string of the form 'x=y'. 121 122 See documentation on numpy.set_printoptions() for details about what values 123 x and y can take. x can be any option listed there other than 'formatter'. 124 125 Args: 126 kv_str: A string of the form 'x=y', such as 'threshold=100000' 127 128 Raises: 129 argparse.ArgumentTypeError: If the string couldn't be used to set any 130 nump printoption. 131 """ 132 k_v_str = kv_str.split("=", 1) 133 if len(k_v_str) != 2 or not k_v_str[0]: 134 raise argparse.ArgumentTypeError("'%s' is not in the form k=v." % kv_str) 135 k, v_str = k_v_str 136 printoptions = np.get_printoptions() 137 if k not in printoptions: 138 raise argparse.ArgumentTypeError("'%s' is not a valid printoption." % k) 139 v_type = type(printoptions[k]) 140 if v_type is type(None): 141 raise argparse.ArgumentTypeError( 142 "Setting '%s' from the command line is not supported." % k) 143 try: 144 v = ( 145 v_type(v_str) 146 if v_type is not bool else flags.BooleanParser().parse(v_str)) 147 except ValueError as e: 148 raise argparse.ArgumentTypeError(e.message) 149 np.set_printoptions(**{k: v}) 150 151 152def main(unused_argv): 153 if not FLAGS.file_name: 154 print("Usage: inspect_checkpoint --file_name=checkpoint_file_name " 155 "[--tensor_name=tensor_to_print] " 156 "[--all_tensors] " 157 "[--all_tensor_names] " 158 "[--printoptions]") 159 sys.exit(1) 160 else: 161 print_tensors_in_checkpoint_file( 162 FLAGS.file_name, FLAGS.tensor_name, 163 FLAGS.all_tensors, FLAGS.all_tensor_names, 164 count_exclude_pattern=FLAGS.count_exclude_pattern) 165 166 167if __name__ == "__main__": 168 parser = argparse.ArgumentParser() 169 parser.register("type", "bool", lambda v: v.lower() == "true") 170 parser.add_argument( 171 "--file_name", 172 type=str, 173 default="", 174 help="Checkpoint filename. " 175 "Note, if using Checkpoint V2 format, file_name is the " 176 "shared prefix between all files in the checkpoint.") 177 parser.add_argument( 178 "--tensor_name", 179 type=str, 180 default="", 181 help="Name of the tensor to inspect") 182 parser.add_argument( 183 "--count_exclude_pattern", 184 type=str, 185 default="", 186 help="Pattern to exclude tensors, e.g., from optimizers, when counting.") 187 parser.add_argument( 188 "--all_tensors", 189 nargs="?", 190 const=True, 191 type="bool", 192 default=False, 193 help="If True, print the names and values of all the tensors.") 194 parser.add_argument( 195 "--all_tensor_names", 196 nargs="?", 197 const=True, 198 type="bool", 199 default=False, 200 help="If True, print the names of all the tensors.") 201 parser.add_argument( 202 "--printoptions", 203 nargs="*", 204 type=parse_numpy_printoption, 205 help="Argument for numpy.set_printoptions(), in the form 'k=v'.") 206 FLAGS, unparsed = parser.parse_known_args() 207 app.run(main=main, argv=[sys.argv[0]] + unparsed) 208