• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""A simple script for inspecting checkpoint files."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import argparse
21import re
22import sys
23
24from absl import app
25import numpy as np
26
27from tensorflow.python.framework import errors_impl
28from tensorflow.python.platform import flags
29from tensorflow.python.training import py_checkpoint_reader
30
# Parsed command-line arguments. Assigned from parser.parse_known_args() in the
# `if __name__ == "__main__":` block at the bottom of this file and read by
# main(); it stays None when this module is merely imported.
FLAGS = None
32
33
34def _count_total_params(reader, count_exclude_pattern=""):
35  """Count total number of variables."""
36  var_to_shape_map = reader.get_variable_to_shape_map()
37
38  # Filter out tensors that we don't want to count
39  if count_exclude_pattern:
40    regex_pattern = re.compile(count_exclude_pattern)
41    new_var_to_shape_map = {}
42    exclude_num_tensors = 0
43    exclude_num_params = 0
44    for v in var_to_shape_map:
45      if regex_pattern.search(v):
46        exclude_num_tensors += 1
47        exclude_num_params += np.prod(var_to_shape_map[v])
48      else:
49        new_var_to_shape_map[v] = var_to_shape_map[v]
50    var_to_shape_map = new_var_to_shape_map
51    print("# Excluding %d tensors (%d params) that match %s when counting." % (
52        exclude_num_tensors, exclude_num_params, count_exclude_pattern))
53
54  var_sizes = [np.prod(var_to_shape_map[v]) for v in var_to_shape_map]
55  return np.sum(var_sizes, dtype=int)
56
57
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors,
                                     all_tensor_names=False,
                                     count_exclude_pattern=""):
  """Prints tensors in a checkpoint file.

  If `all_tensors` or `all_tensor_names` is set, prints the name, dtype and
  shape of every tensor in the checkpoint, plus each tensor's value when
  `all_tensors` is set.

  Otherwise, if no `tensor_name` is provided, prints the reader's debug string
  (the tensor names and shapes in the checkpoint file).

  If `tensor_name` is provided, prints its name, dtype, shape and content.

  In every case except a missing `tensor_name`, the total number of parameters
  is printed last.

  This is a best-effort command-line helper and does not raise. Any exception
  is printed, followed by hints for common causes: a SNAPPY-compressed file,
  or a V2 checkpoint given by a component filename instead of its prefix.

  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print.
    all_tensors: Boolean indicating whether to print all tensors.
    all_tensor_names: Boolean indicating whether to print all tensor names.
    count_exclude_pattern: Regex string, pattern to exclude tensors when count.
  """
  try:
    reader = py_checkpoint_reader.NewCheckpointReader(file_name)
    if all_tensors or all_tensor_names:
      var_to_shape_map = reader.get_variable_to_shape_map()
      var_to_dtype_map = reader.get_variable_to_dtype_map()
      for key, value in sorted(var_to_shape_map.items()):
        print("tensor: %s (%s) %s" % (key, var_to_dtype_map[key].name, value))
        if all_tensors:
          try:
            print(reader.get_tensor(key))
          except errors_impl.InternalError:
            # Keep listing the remaining tensors if one cannot be materialized
            # as a numpy value.
            print("<not convertible to a numpy dtype>")
    elif not tensor_name:
      print(reader.debug_string().decode("utf-8", errors="ignore"))
    else:
      if not reader.has_tensor(tensor_name):
        print("Tensor %s not found in checkpoint" % tensor_name)
        return

      var_to_shape_map = reader.get_variable_to_shape_map()
      var_to_dtype_map = reader.get_variable_to_dtype_map()
      print("tensor: %s (%s) %s" %
            (tensor_name, var_to_dtype_map[tensor_name].name,
             var_to_shape_map[tensor_name]))
      print(reader.get_tensor(tensor_name))

    # Count total number of parameters
    print("# Total number of params: %d" % _count_total_params(
        reader, count_exclude_pattern=count_exclude_pattern))
  except Exception as e:  # pylint: disable=broad-except
    error_message = str(e)
    print(error_message)
    if "corrupted compressed block contents" in error_message:
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")
    # A V2 checkpoint is addressed by its shared filename prefix. When a
    # "Data loss" error is reported for a name that looks like one of its
    # component files, suggest the name with the last extension stripped.
    if ("Data loss" in error_message and
        any(ext in file_name for ext in [".index", ".meta", ".data"])):
      proposed_file = ".".join(file_name.split(".")[0:-1])
      # The suggested command must be copy-pasteable: the tool is invoked as
      # `inspect_checkpoint`, and argparse needs `--file_name=<value>` without
      # spaces around `=`.
      v2_file_error_template = """
It's likely that this is a V2 checkpoint and you need to provide the filename
*prefix*.  Try removing the '.' and extension.  Try:
inspect_checkpoint --file_name={}"""
      print(v2_file_error_template.format(proposed_file))
117
118
def parse_numpy_printoption(kv_str):
  """Sets a single numpy printoption from a string of the form 'x=y'.

  Intended as an argparse `type=` converter. It applies the option as a side
  effect and returns None.

  See documentation on numpy.set_printoptions() for details about what values
  x and y can take. x can be any option listed there other than 'formatter'.
  The value y is converted using the type of the option's current value.
  Boolean options are parsed with `flags.BooleanParser`.

  Args:
    kv_str: A string of the form 'x=y', such as 'threshold=100000'

  Raises:
    argparse.ArgumentTypeError: If the string couldn't be used to set any
        numpy printoption.
  """
  k_v_str = kv_str.split("=", 1)
  if len(k_v_str) != 2 or not k_v_str[0]:
    raise argparse.ArgumentTypeError("'%s' is not in the form k=v." % kv_str)
  k, v_str = k_v_str
  printoptions = np.get_printoptions()
  if k not in printoptions:
    raise argparse.ArgumentTypeError("'%s' is not a valid printoption." % k)
  v_type = type(printoptions[k])
  # Options whose current value is None (e.g. 'formatter') have no type to
  # convert the string into.
  if v_type is type(None):
    raise argparse.ArgumentTypeError(
        "Setting '%s' from the command line is not supported." % k)
  try:
    v = (
        v_type(v_str)
        if v_type is not bool else flags.BooleanParser().parse(v_str))
    # numpy also validates some values itself (e.g. 'floatmode') and raises
    # ValueError. Convert that too so argparse shows numpy's message.
    np.set_printoptions(**{k: v})
  except ValueError as e:
    # Python 3 exceptions have no `.message` attribute; str(e) works
    # everywhere.
    raise argparse.ArgumentTypeError(str(e))
150
151
def main(unused_argv):
  """Entry point passed to `app.run`; inspects the checkpoint named in FLAGS.

  Prints a usage message and exits with status 1 when `--file_name` is empty.
  Otherwise delegates to `print_tensors_in_checkpoint_file`.

  Args:
    unused_argv: Remaining command-line arguments forwarded by `app.run`;
      ignored.
  """
  if not FLAGS.file_name:
    # Keep this usage line in sync with the flags registered on the parser.
    print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
          "[--tensor_name=tensor_to_print] "
          "[--count_exclude_pattern=regex] "
          "[--all_tensors] "
          "[--all_tensor_names] "
          "[--printoptions k=v ...]")
    sys.exit(1)

  print_tensors_in_checkpoint_file(
      FLAGS.file_name, FLAGS.tensor_name,
      FLAGS.all_tensors, FLAGS.all_tensor_names,
      count_exclude_pattern=FLAGS.count_exclude_pattern)
165
166
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  # Custom "bool" argument type: only the case-insensitive string "true" is
  # treated as True; every other explicit value becomes False.
  parser.register("type", "bool", lambda v: v.lower() == "true")

  # Free-form string flags, each defaulting to the empty string.
  string_flag_specs = (
      ("--file_name",
       "Checkpoint filename. "
       "Note, if using Checkpoint V2 format, file_name is the "
       "shared prefix between all files in the checkpoint."),
      ("--tensor_name", "Name of the tensor to inspect"),
      ("--count_exclude_pattern",
       "Pattern to exclude tensors, e.g., from optimizers, when counting."),
  )
  for flag_name, flag_help in string_flag_specs:
    parser.add_argument(flag_name, type=str, default="", help=flag_help)

  # Boolean switches: a bare `--flag` yields True through `const`, while an
  # explicit value is converted by the "bool" type registered above.
  bool_flag_specs = (
      ("--all_tensors",
       "If True, print the names and values of all the tensors."),
      ("--all_tensor_names",
       "If True, print the names of all the tensors."),
  )
  for flag_name, flag_help in bool_flag_specs:
    parser.add_argument(flag_name, nargs="?", const=True, type="bool",
                        default=False, help=flag_help)

  # Each k=v item is applied to numpy's print options while parsing.
  parser.add_argument(
      "--printoptions",
      nargs="*",
      type=parse_numpy_printoption,
      help="Argument for numpy.set_printoptions(), in the form 'k=v'.")

  # Arguments this parser does not recognize are forwarded to app.run.
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)
208