# (Code-browser navigation chrome removed during review.)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions that help to inspect Python source w.r.t. TF graphs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import re
import zipfile

import absl
import numpy as np

from tensorflow.python.debug.lib import profiling

31
32_TENSORFLOW_BASEDIR = os.path.dirname(
33    os.path.dirname(os.path.dirname(os.path.dirname(
34        os.path.normpath(os.path.abspath(__file__))))))
35
36_ABSL_BASEDIR = os.path.dirname(absl.__file__)
37
38
39UNCOMPILED_SOURCE_SUFFIXES = (".py")
40COMPILED_SOURCE_SUFFIXES = (".pyc", ".pyo")
41
42
43def _norm_abs_path(file_path):
44  return os.path.normpath(os.path.abspath(file_path))
45
46
47def is_extension_uncompiled_python_source(file_path):
48  _, extension = os.path.splitext(file_path)
49  return extension.lower() in UNCOMPILED_SOURCE_SUFFIXES
50
51
52def is_extension_compiled_python_source(file_path):
53  _, extension = os.path.splitext(file_path)
54  return extension.lower() in COMPILED_SOURCE_SUFFIXES
55
56
57def _convert_watch_key_to_tensor_name(watch_key):
58  return watch_key[:watch_key.rfind(":")]
59
60
61def guess_is_tensorflow_py_library(py_file_path):
62  """Guess whether a Python source file is a part of the tensorflow library.
63
64  Special cases:
65    1) Returns False for unit-test files in the library (*_test.py),
66    2) Returns False for files under python/debug/examples.
67
68  Args:
69    py_file_path: full path of the Python source file in question.
70
71  Returns:
72    (`bool`) Whether the file is inferred to be a part of the tensorflow
73      library.
74  """
75  if (not is_extension_uncompiled_python_source(py_file_path) and
76      not is_extension_compiled_python_source(py_file_path)):
77    return False
78  py_file_path = _norm_abs_path(py_file_path)
79  return ((py_file_path.startswith(_TENSORFLOW_BASEDIR) or
80           py_file_path.startswith(_ABSL_BASEDIR)) and
81          not py_file_path.endswith("_test.py") and
82          (os.path.normpath("tensorflow/python/debug/examples") not in
83           os.path.normpath(py_file_path)))
84
85
86def load_source(source_file_path):
87  """Load the content of a Python source code file.
88
89  This function covers the following case:
90    1. source_file_path points to an existing Python (.py) file on the
91       file system.
92    2. source_file_path is a path within a .par file (i.e., a zip-compressed,
93       self-contained Python executable).
94
95  Args:
96    source_file_path: Path to the Python source file to read.
97
98  Returns:
99    A length-2 tuple:
100      - Lines of the source file, as a `list` of `str`s.
101      - The width of the string needed to show the line number in the file.
102        This is calculated based on the number of lines in the source file.
103
104  Raises:
105    IOError: if loading is unsuccessful.
106  """
107  if os.path.isfile(source_file_path):
108    with open(source_file_path, "rb") as f:
109      source_text = f.read().decode("utf-8")
110    source_lines = source_text.split("\n")
111  else:
112    # One possible reason why the file doesn't exist is that it's a path
113    # inside a .par file. Try that possibility.
114    source_lines = _try_load_par_source(source_file_path)
115    if source_lines is None:
116      raise IOError(
117          "Source path neither exists nor can be loaded as a .par file: %s" %
118          source_file_path)
119  line_num_width = int(np.ceil(np.log10(len(source_lines)))) + 3
120  return source_lines, line_num_width
121
122
123def _try_load_par_source(source_file_path):
124  """Try loading the source code inside a .par file.
125
126  A .par file is a zip-compressed, self-contained Python executable.
127  It contains the content of individual Python source files that can
128  be read only through extracting from the zip file.
129
130  Args:
131    source_file_path: The full path to the file inside the .par file. This
132      path should include the path to the .par file itself, followed by the
133      intra-par path, e.g.,
134      "/tmp/my_executable.par/org-tensorflow/tensorflow/python/foo/bar.py".
135
136  Returns:
137    If successful, lines of the source file as a `list` of `str`s.
138    Else, `None`.
139  """
140  prefix_path = source_file_path
141  while True:
142    prefix_path, basename = os.path.split(prefix_path)
143    if not basename:
144      break
145    suffix_path = os.path.normpath(
146        os.path.relpath(source_file_path, start=prefix_path))
147    if prefix_path.endswith(".par") and os.path.isfile(prefix_path):
148      with zipfile.ZipFile(prefix_path) as z:
149        norm_names = [os.path.normpath(name) for name in z.namelist()]
150        if suffix_path in norm_names:
151          with z.open(z.namelist()[norm_names.index(suffix_path)]) as zf:
152            source_text = zf.read().decode("utf-8")
153            return source_text.split("\n")
154
155
156def annotate_source(dump,
157                    source_file_path,
158                    do_dumped_tensors=False,
159                    file_stack_top=False,
160                    min_line=None,
161                    max_line=None):
162  """Annotate a Python source file with a list of ops created at each line.
163
164  (The annotation doesn't change the source file itself.)
165
166  Args:
167    dump: (`DebugDumpDir`) A `DebugDumpDir` object of which the Python graph
168      has been loaded.
169    source_file_path: (`str`) Path to the source file being annotated.
170    do_dumped_tensors: (`str`) Whether dumped Tensors, instead of ops are to be
171      used to annotate the source file.
172    file_stack_top: (`bool`) Whether only the top stack trace in the
173      specified source file is to be annotated.
174    min_line: (`None` or `int`) The 1-based line to start annotate the source
175      file from (inclusive).
176    max_line: (`None` or `int`) The 1-based line number to end the annotation
177      at (exclusive).
178
179  Returns:
180    A `dict` mapping 1-based line number to a list of op name(s) created at
181      that line, or tensor names if `do_dumped_tensors` is True.
182
183  Raises:
184    ValueError: If the dump object does not have a Python graph set.
185  """
186
187  py_graph = dump.python_graph
188  if not py_graph:
189    raise ValueError("Cannot perform source annotation due to a lack of set "
190                     "Python graph in the dump object")
191
192  source_file_path = _norm_abs_path(source_file_path)
193
194  line_to_op_names = {}
195  for op in py_graph.get_operations():
196    for file_path, line_number, _, _ in reversed(dump.node_traceback(op.name)):
197      if (min_line is not None and line_number < min_line or
198          max_line is not None and line_number >= max_line):
199        continue
200
201      if _norm_abs_path(file_path) != source_file_path:
202        continue
203
204      if do_dumped_tensors:
205        watch_keys = dump.debug_watch_keys(op.name)
206        # Convert watch keys to unique Tensor names.
207        items_to_append = list(
208            set(map(_convert_watch_key_to_tensor_name, watch_keys)))
209      else:
210        items_to_append = [op.name]
211
212      if line_number in line_to_op_names:
213        line_to_op_names[line_number].extend(items_to_append)
214      else:
215        line_to_op_names[line_number] = items_to_append
216
217      if file_stack_top:
218        break
219
220  return line_to_op_names
221
222
223def list_source_files_against_dump(dump,
224                                   path_regex_allowlist=None,
225                                   node_name_regex_allowlist=None):
226  """Generate a list of source files with information regarding ops and tensors.
227
228  Args:
229    dump: (`DebugDumpDir`) A `DebugDumpDir` object of which the Python graph
230      has been loaded.
231    path_regex_allowlist: A regular-expression filter for source file path.
232    node_name_regex_allowlist: A regular-expression filter for node names.
233
234  Returns:
235    A list of tuples regarding the Python source files involved in constructing
236    the ops and tensors contained in `dump`. Each tuple is:
237      (source_file_path, is_tf_library, num_nodes, num_tensors, num_dumps,
238       first_line)
239
240      is_tf_library: (`bool`) A guess of whether the file belongs to the
241        TensorFlow Python library.
242      num_nodes: How many nodes were created by lines of this source file.
243        These include nodes with dumps and those without.
244      num_tensors: How many Tensors were created by lines of this source file.
245        These include Tensors with dumps and those without.
246      num_dumps: How many debug Tensor dumps were from nodes (and Tensors)
247        that were created by this source file.
248      first_line: The first line number (1-based) that created any nodes or
249        Tensors in this source file.
250
251    The list is sorted by ascending order of source_file_path.
252
253  Raises:
254    ValueError: If the dump object does not have a Python graph set.
255  """
256
257  py_graph = dump.python_graph
258  if not py_graph:
259    raise ValueError("Cannot generate source list due to a lack of set "
260                     "Python graph in the dump object")
261
262  path_to_node_names = collections.defaultdict(set)
263  path_to_tensor_names = collections.defaultdict(set)
264  path_to_first_line = {}
265  tensor_name_to_num_dumps = {}
266
267  path_regex = (
268      re.compile(path_regex_allowlist) if path_regex_allowlist else None)
269  node_name_regex = (
270      re.compile(node_name_regex_allowlist)
271      if node_name_regex_allowlist else None)
272
273  to_skip_file_paths = set()
274  for op in py_graph.get_operations():
275    if node_name_regex and not node_name_regex.match(op.name):
276      continue
277
278    for file_path, line_number, _, _ in dump.node_traceback(op.name):
279      file_path = _norm_abs_path(file_path)
280      if (file_path in to_skip_file_paths or
281          path_regex and not path_regex.match(file_path) or
282          not os.path.isfile(file_path)):
283        to_skip_file_paths.add(file_path)
284        continue
285
286      path_to_node_names[file_path].add(op.name)
287      if file_path in path_to_first_line:
288        if path_to_first_line[file_path] > line_number:
289          path_to_first_line[file_path] = line_number
290      else:
291        path_to_first_line[file_path] = line_number
292
293      for output_tensor in op.outputs:
294        tensor_name = output_tensor.name
295        path_to_tensor_names[file_path].add(tensor_name)
296
297      watch_keys = dump.debug_watch_keys(op.name)
298      for watch_key in watch_keys:
299        node_name, output_slot, debug_op = watch_key.split(":")
300        tensor_name = "%s:%s" % (node_name, output_slot)
301        if tensor_name not in tensor_name_to_num_dumps:
302          tensor_name_to_num_dumps[tensor_name] = len(
303              dump.get_tensors(node_name, int(output_slot), debug_op))
304
305  path_to_num_dumps = {}
306  for path in path_to_tensor_names:
307    path_to_num_dumps[path] = sum(
308        tensor_name_to_num_dumps.get(tensor_name, 0)
309        for tensor_name in path_to_tensor_names[path])
310
311  output = []
312  for file_path in path_to_node_names:
313    output.append((
314        file_path,
315        guess_is_tensorflow_py_library(file_path),
316        len(path_to_node_names.get(file_path, {})),
317        len(path_to_tensor_names.get(file_path, {})),
318        path_to_num_dumps.get(file_path, 0),
319        path_to_first_line[file_path]))
320
321  return sorted(output, key=lambda x: x[0])
322
323
324def annotate_source_against_profile(profile_data,
325                                    source_file_path,
326                                    node_name_filter=None,
327                                    op_type_filter=None,
328                                    min_line=None,
329                                    max_line=None):
330  """Annotate a Python source file with profiling information at each line.
331
332  (The annotation doesn't change the source file itself.)
333
334  Args:
335    profile_data: (`list` of `ProfileDatum`) A list of `ProfileDatum`.
336    source_file_path: (`str`) Path to the source file being annotated.
337    node_name_filter: Regular expression to filter by node name.
338    op_type_filter: Regular expression to filter by op type.
339    min_line: (`None` or `int`) The 1-based line to start annotate the source
340      file from (inclusive).
341    max_line: (`None` or `int`) The 1-based line number to end the annotation
342      at (exclusive).
343
344  Returns:
345    A `dict` mapping 1-based line number to a the namedtuple
346      `profiling.LineOrFuncProfileSummary`.
347  """
348
349  source_file_path = _norm_abs_path(source_file_path)
350
351  node_name_regex = re.compile(node_name_filter) if node_name_filter else None
352  op_type_regex = re.compile(op_type_filter) if op_type_filter else None
353
354  line_to_profile_summary = {}
355  for profile_datum in profile_data:
356    if not profile_datum.file_path:
357      continue
358
359    if _norm_abs_path(profile_datum.file_path) != source_file_path:
360      continue
361
362    if (min_line is not None and profile_datum.line_number < min_line or
363        max_line is not None and profile_datum.line_number >= max_line):
364      continue
365
366    if (node_name_regex and
367        not node_name_regex.match(profile_datum.node_exec_stats.node_name)):
368      continue
369
370    if op_type_regex and not op_type_regex.match(profile_datum.op_type):
371      continue
372
373    if profile_datum.line_number not in line_to_profile_summary:
374      line_to_profile_summary[profile_datum.line_number] = (
375          profiling.AggregateProfile(profile_datum))
376    else:
377      line_to_profile_summary[profile_datum.line_number].add(profile_datum)
378
379  return line_to_profile_summary
380