• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Communicating tracebacks and source code with debug server."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import socket
22
23import grpc
24
25from tensorflow.core.debug import debug_service_pb2
26from tensorflow.core.protobuf import debug_pb2
27from tensorflow.python.debug.lib import common
28from tensorflow.python.debug.lib import debug_service_pb2_grpc
29from tensorflow.python.debug.lib import source_utils
30from tensorflow.python.platform import gfile
31from tensorflow.python.profiler import tfprof_logger
32
33
def _load_debugged_source_file(file_path, source_file_proto):
  """Fill a source-file proto with the metadata and lines of a file.

  The file is stat'ed first; the local host name, the path, the modification
  time (in nanoseconds, as reported by `gfile.Stat`) and the byte length are
  then written to the proto. Reading the content is best-effort: if an
  `IOError` is raised while opening or reading the file, the proto keeps the
  metadata but receives no lines.

  Args:
    file_path: Path to the source file to load.
    source_file_proto: The proto to populate in place. Its `host`,
      `file_path`, `last_modified`, `bytes` and `lines` fields are written.
  """
  stat_result = gfile.Stat(file_path)
  source_file_proto.host = socket.gethostname()
  source_file_proto.file_path = file_path
  source_file_proto.last_modified = stat_result.mtime_nsec
  source_file_proto.bytes = stat_result.length
  try:
    with gfile.Open(file_path, "r") as source_file:
      content = source_file.read()
      source_file_proto.lines.extend(content.splitlines())
  except IOError:
    # Deliberately ignored: the metadata above is still reported even when
    # the file content cannot be read.
    pass
45
46
def _string_to_id(string, string_to_id):
  """Return the integer id of `string`, assigning a new id on first sight.

  Args:
    string: The key to look up (any hashable value, including `None`).
    string_to_id: A dict mapping keys to integer ids. If `string` is not yet
      a key, it is inserted with the dict's current size as its id.

  Returns:
    The id stored for `string` in `string_to_id`.
  """
  # `len()` is evaluated before any insertion, so a new key receives the
  # size of the table prior to being added.
  return string_to_id.setdefault(string, len(string_to_id))
51
52
def _format_origin_stack(origin_stack, call_traceback_proto):
  """Write a traceback stack into a `CallTraceback` proto.

  Every file path, function name and line text in the stack is replaced by an
  integer id. One trace is appended to `origin_stack.traces` per frame, and
  the id-to-string table is stored in `origin_id_to_string`. Id 0 is reserved
  for `None`, which is stored as the empty string.

  Args:
    origin_stack: The stack list as returned by `traceback.extract_stack()`;
      each frame unpacks into (file path, line number, function name, line
      text).
    call_traceback_proto: A `CallTraceback` proto whose `origin_stack` and
      `origin_id_to_string` fields are to be populated.
  """
  ids_by_string = {None: 0}
  for file_path, lineno, func_name, line_text in origin_stack:
    file_id = _string_to_id(file_path, ids_by_string)
    function_id = _string_to_id(func_name, ids_by_string)
    line_id = _string_to_id(line_text, ids_by_string)
    call_traceback_proto.origin_stack.traces.add(
        file_id=file_id,
        lineno=lineno,
        function_id=function_id,
        line_id=line_id)

  # Invert the table into the proto map; `None` becomes the empty string.
  strings_by_id = call_traceback_proto.origin_id_to_string
  for text, text_id in ids_by_string.items():
    strings_by_id[text_id] = "" if text is None else text
74
75
def _source_file_paths_outside_tensorflow_py_library(code_defs, id_to_string):
  """Find existing source files that are not part of the TensorFlow library.

  The file ids referenced by all traces are collected eagerly. Translating
  them to paths and filtering them happens lazily, as the returned generator
  is consumed.

  Args:
    code_defs: An iterable of `CodeDef` protos, i.e., an iterable of stack
      traces.
    id_to_string: A proto map from integer ids to strings.

  Returns:
    A generator of the source file paths that
    `source_utils.guess_is_tensorflow_py_library` does not attribute to the
    TensorFlow Python library and that exist according to `gfile.Exists`.
  """
  def _is_existing_non_tf_file(path):
    # The library guess runs first so that the file system is only queried
    # for paths outside the TensorFlow Python library.
    if source_utils.guess_is_tensorflow_py_library(path):
      return False
    return gfile.Exists(path)

  referenced_file_ids = {
      trace.file_id for code_def in code_defs for trace in code_def.traces}
  candidate_paths = (id_to_string[file_id] for file_id in referenced_file_ids)
  return (path for path in candidate_paths if _is_existing_non_tf_file(path))
96
97
def _send_call_tracebacks(destinations,
                          origin_stack,
                          is_eager_execution=False,
                          call_key=None,
                          graph=None,
                          send_source=True):
  """Send the tracebacks of a TensorFlow execution call to debug server(s).

  The tracebacks are sent over gRPC. This applies to graph execution
  (`tf.Session.run()`) calls and eager execution calls.

  If `send_source`, also sends the underlying source files outside the
  TensorFlow library.

  One insecure gRPC channel is opened per destination and is closed once the
  requests for that destination have been issued. Exceptions raised by the
  gRPC calls are not caught here; the channel is closed before they propagate.

  Args:
    destinations: gRPC destination addresses, a `str` or a `list` (or `tuple`)
      of `str`s, e.g., "localhost:4242". If a sequence, gRPC requests
      containing the same `CallTraceback` proto payload will be sent to all
      the destinations. A leading "grpc://" prefix is stripped.
    origin_stack: The traceback stack for the origin of the execution call. For
      graph execution, this is the traceback of the `tf.Session.run()`
      invocation. For eager execution, this is the traceback of the Python
      line that executes the eager operation.
    is_eager_execution: (`bool`) whether an eager execution call (i.e., not a
      `tf.Session.run` or derived methods) is being sent.
    call_key: The key of the execution call, as a string. For graph execution,
      this is a string describing the feeds, fetches (and targets) names of the
      `tf.Session.run` call. For eager execution, this is ignored.
    graph: A Python `tf.Graph` object (i.e., *not* a `tf.compat.v1.GraphDef`),
      which contains op tracebacks, if applicable.
    send_source: Whether the source files involved in the op tracebacks but
      outside the TensorFlow library are to be sent.
  """
  if not isinstance(destinations, (list, tuple)):
    destinations = [destinations]
  # Strip grpc:// prefix, if any is present.
  destinations = [
      dest[len(common.GRPC_URL_PREFIX):]
      if dest.startswith(common.GRPC_URL_PREFIX) else dest
      for dest in destinations]

  call_type = (debug_service_pb2.CallTraceback.EAGER_EXECUTION
               if is_eager_execution
               else debug_service_pb2.CallTraceback.GRAPH_EXECUTION)
  graph_traceback = tfprof_logger.merge_default_with_oplog(
      graph, add_trainable_var=False) if graph else None
  call_traceback = debug_service_pb2.CallTraceback(
      call_type=call_type, call_key=call_key, graph_traceback=graph_traceback,
      graph_version=graph.version if graph else None)

  _format_origin_stack(origin_stack, call_traceback)

  # Stays empty when `send_source` is false, so the send loop below needs no
  # separate check.
  debugged_source_files = []
  if send_source:
    # Source files referenced by the op tracebacks of the graph and by the
    # traceback of the call itself.
    source_file_paths = set()
    source_file_paths.update(_source_file_paths_outside_tensorflow_py_library(
        (log_entry.code_def for log_entry
         in call_traceback.graph_traceback.log_entries),
        call_traceback.graph_traceback.id_to_string))
    source_file_paths.update(_source_file_paths_outside_tensorflow_py_library(
        [call_traceback.origin_stack], call_traceback.origin_id_to_string))

    # Each file goes into its own `DebuggedSourceFiles` proto, i.e., one
    # request per source file.
    for file_path in source_file_paths:
      source_files = debug_pb2.DebuggedSourceFiles()
      _load_debugged_source_file(
          file_path, source_files.source_files.add())
      debugged_source_files.append(source_files)

  # The channel options are the same for every destination: -1 lifts the
  # message size limits in both directions.
  no_max_message_sizes = [("grpc.max_receive_message_length", -1),
                          ("grpc.max_send_message_length", -1)]
  for destination in destinations:
    channel = grpc.insecure_channel(destination, options=no_max_message_sizes)
    try:
      stub = debug_service_pb2_grpc.EventListenerStub(channel)
      stub.SendTracebacks(call_traceback)
      for source_files in debugged_source_files:
        stub.SendSourceFiles(source_files)
    finally:
      # Release the channel's resources even if a request fails.
      channel.close()
174
175
def send_graph_tracebacks(destinations,
                          run_key,
                          origin_stack,
                          graph,
                          send_source=True):
  """Send the tracebacks of a graph execution call to debug server(s).

  Delegates to `_send_call_tracebacks` with `is_eager_execution=False` and
  `run_key` as the call key.

  Args:
    destinations: gRPC destination addresses, a `str` or a `list` of `str`s,
      e.g., "localhost:4242". If a `list`, gRPC requests containing the same
      `CallTraceback` proto payload will be sent to all the destinations.
    run_key: A string describing the feeds, fetches (and targets) names of the
      `tf.Session.run` call.
    origin_stack: The traceback of the `tf.Session.run()` invocation.
    graph: A Python `tf.Graph` object (i.e., *not* a `tf.compat.v1.GraphDef`),
      which contains op tracebacks.
    send_source: Whether the source files involved in the op tracebacks but
      outside the TensorFlow library are to be sent.
  """
  _send_call_tracebacks(
      destinations,
      origin_stack,
      is_eager_execution=False,
      call_key=run_key,
      graph=graph,
      send_source=send_source)
198
199
def send_eager_tracebacks(destinations,
                          origin_stack,
                          send_source=True):
  """Send the tracebacks of an eager execution call to debug server(s).

  Delegates to `_send_call_tracebacks` with `is_eager_execution=True`; no call
  key and no graph are passed.

  Args:
    destinations: gRPC destination addresses, a `str` or a `list` of `str`s,
      e.g., "localhost:4242". If a `list`, gRPC requests containing the same
      `CallTraceback` proto payload will be sent to all the destinations.
    origin_stack: The traceback of the eager operation invocation.
    send_source: Whether the source files involved in the op tracebacks but
      outside the TensorFlow library are to be sent.
  """
  _send_call_tracebacks(
      destinations,
      origin_stack,
      is_eager_execution=True,
      send_source=send_source)
215