1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Shared functions and classes for tfdbg command-line interface.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import math 21 22import numpy as np 23import six 24 25from tensorflow.python.debug.cli import command_parser 26from tensorflow.python.debug.cli import debugger_cli_common 27from tensorflow.python.debug.cli import tensor_format 28from tensorflow.python.debug.lib import common 29from tensorflow.python.framework import ops 30from tensorflow.python.ops import variables 31from tensorflow.python.platform import gfile 32 33RL = debugger_cli_common.RichLine 34 35# Default threshold number of elements above which ellipses will be used 36# when printing the value of the tensor. 37DEFAULT_NDARRAY_DISPLAY_THRESHOLD = 2000 38 39COLOR_BLACK = "black" 40COLOR_BLUE = "blue" 41COLOR_CYAN = "cyan" 42COLOR_GRAY = "gray" 43COLOR_GREEN = "green" 44COLOR_MAGENTA = "magenta" 45COLOR_RED = "red" 46COLOR_WHITE = "white" 47COLOR_YELLOW = "yellow" 48 49TIME_UNIT_US = "us" 50TIME_UNIT_MS = "ms" 51TIME_UNIT_S = "s" 52TIME_UNITS = [TIME_UNIT_US, TIME_UNIT_MS, TIME_UNIT_S] 53 54 55def bytes_to_readable_str(num_bytes, include_b=False): 56 """Generate a human-readable string representing number of bytes. 57 58 The units B, kB, MB and GB are used. 59 60 Args: 61 num_bytes: (`int` or None) Number of bytes. 62 include_b: (`bool`) Include the letter B at the end of the unit. 63 64 Returns: 65 (`str`) A string representing the number of bytes in a human-readable way, 66 including a unit at the end. 67 """ 68 69 if num_bytes is None: 70 return str(num_bytes) 71 if num_bytes < 1024: 72 result = "%d" % num_bytes 73 elif num_bytes < 1048576: 74 result = "%.2fk" % (num_bytes / 1024.0) 75 elif num_bytes < 1073741824: 76 result = "%.2fM" % (num_bytes / 1048576.0) 77 else: 78 result = "%.2fG" % (num_bytes / 1073741824.0) 79 80 if include_b: 81 result += "B" 82 return result 83 84 85def time_to_readable_str(value_us, force_time_unit=None): 86 """Convert time value to human-readable string. 87 88 Args: 89 value_us: time value in microseconds. 90 force_time_unit: force the output to use the specified time unit. Must be 91 in TIME_UNITS. 92 93 Returns: 94 Human-readable string representation of the time value. 95 96 Raises: 97 ValueError: if force_time_unit value is not in TIME_UNITS. 98 """ 99 if not value_us: 100 return "0" 101 if force_time_unit: 102 if force_time_unit not in TIME_UNITS: 103 raise ValueError("Invalid time unit: %s" % force_time_unit) 104 order = TIME_UNITS.index(force_time_unit) 105 time_unit = force_time_unit 106 return "{:.10g}{}".format(value_us / math.pow(10.0, 3*order), time_unit) 107 else: 108 order = min(len(TIME_UNITS) - 1, int(math.log(value_us, 10) / 3)) 109 time_unit = TIME_UNITS[order] 110 return "{:.3g}{}".format(value_us / math.pow(10.0, 3*order), time_unit) 111 112 113def parse_ranges_highlight(ranges_string): 114 """Process ranges highlight string. 115 116 Args: 117 ranges_string: (str) A string representing a numerical range of a list of 118 numerical ranges. See the help info of the -r flag of the print_tensor 119 command for more details. 120 121 Returns: 122 An instance of tensor_format.HighlightOptions, if range_string is a valid 123 representation of a range or a list of ranges. 124 """ 125 126 ranges = None 127 128 def ranges_filter(x): 129 r = np.zeros(x.shape, dtype=bool) 130 for range_start, range_end in ranges: 131 r = np.logical_or(r, np.logical_and(x >= range_start, x <= range_end)) 132 133 return r 134 135 if ranges_string: 136 ranges = command_parser.parse_ranges(ranges_string) 137 return tensor_format.HighlightOptions( 138 ranges_filter, description=ranges_string) 139 else: 140 return None 141 142 143def numpy_printoptions_from_screen_info(screen_info): 144 if screen_info and "cols" in screen_info: 145 return {"linewidth": screen_info["cols"]} 146 else: 147 return {} 148 149 150def format_tensor(tensor, 151 tensor_name, 152 np_printoptions, 153 print_all=False, 154 tensor_slicing=None, 155 highlight_options=None, 156 include_numeric_summary=False, 157 write_path=None): 158 """Generate formatted str to represent a tensor or its slices. 159 160 Args: 161 tensor: (numpy ndarray) The tensor value. 162 tensor_name: (str) Name of the tensor, e.g., the tensor's debug watch key. 163 np_printoptions: (dict) Numpy tensor formatting options. 164 print_all: (bool) Whether the tensor is to be displayed in its entirety, 165 instead of printing ellipses, even if its number of elements exceeds 166 the default numpy display threshold. 167 (Note: Even if this is set to true, the screen output can still be cut 168 off by the UI frontend if it consist of more lines than the frontend 169 can handle.) 170 tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If 171 None, no slicing will be performed on the tensor. 172 highlight_options: (tensor_format.HighlightOptions) options to highlight 173 elements of the tensor. See the doc of tensor_format.format_tensor() 174 for more details. 175 include_numeric_summary: Whether a text summary of the numeric values (if 176 applicable) will be included. 177 write_path: A path to save the tensor value (after any slicing) to 178 (optional). `numpy.save()` is used to save the value. 179 180 Returns: 181 An instance of `debugger_cli_common.RichTextLines` representing the 182 (potentially sliced) tensor. 183 """ 184 185 if tensor_slicing: 186 # Validate the indexing. 187 value = command_parser.evaluate_tensor_slice(tensor, tensor_slicing) 188 sliced_name = tensor_name + tensor_slicing 189 else: 190 value = tensor 191 sliced_name = tensor_name 192 193 auxiliary_message = None 194 if write_path: 195 with gfile.Open(write_path, "wb") as output_file: 196 np.save(output_file, value) 197 line = debugger_cli_common.RichLine("Saved value to: ") 198 line += debugger_cli_common.RichLine(write_path, font_attr="bold") 199 line += " (%sB)" % bytes_to_readable_str(gfile.Stat(write_path).length) 200 auxiliary_message = debugger_cli_common.rich_text_lines_from_rich_line_list( 201 [line, debugger_cli_common.RichLine("")]) 202 203 if print_all: 204 np_printoptions["threshold"] = value.size 205 else: 206 np_printoptions["threshold"] = DEFAULT_NDARRAY_DISPLAY_THRESHOLD 207 208 return tensor_format.format_tensor( 209 value, 210 sliced_name, 211 include_metadata=True, 212 include_numeric_summary=include_numeric_summary, 213 auxiliary_message=auxiliary_message, 214 np_printoptions=np_printoptions, 215 highlight_options=highlight_options) 216 217 218def error(msg): 219 """Generate a RichTextLines output for error. 220 221 Args: 222 msg: (str) The error message. 223 224 Returns: 225 (debugger_cli_common.RichTextLines) A representation of the error message 226 for screen output. 227 """ 228 229 return debugger_cli_common.rich_text_lines_from_rich_line_list([ 230 RL("ERROR: " + msg, COLOR_RED)]) 231 232 233def _recommend_command(command, description, indent=2, create_link=False): 234 """Generate a RichTextLines object that describes a recommended command. 235 236 Args: 237 command: (str) The command to recommend. 238 description: (str) A description of what the command does. 239 indent: (int) How many spaces to indent in the beginning. 240 create_link: (bool) Whether a command link is to be applied to the command 241 string. 242 243 Returns: 244 (RichTextLines) Formatted text (with font attributes) for recommending the 245 command. 246 """ 247 248 indent_str = " " * indent 249 250 if create_link: 251 font_attr = [debugger_cli_common.MenuItem("", command), "bold"] 252 else: 253 font_attr = "bold" 254 255 lines = [RL(indent_str) + RL(command, font_attr) + ":", 256 indent_str + " " + description] 257 258 return debugger_cli_common.rich_text_lines_from_rich_line_list(lines) 259 260 261def get_tfdbg_logo(): 262 """Make an ASCII representation of the tfdbg logo.""" 263 264 lines = [ 265 "", 266 "TTTTTT FFFF DDD BBBB GGG ", 267 " TT F D D B B G ", 268 " TT FFF D D BBBB G GG", 269 " TT F D D B B G G", 270 " TT F DDD BBBB GGG ", 271 "", 272 ] 273 return debugger_cli_common.RichTextLines(lines) 274 275 276_HORIZONTAL_BAR = "======================================" 277 278 279def get_run_start_intro(run_call_count, 280 fetches, 281 feed_dict, 282 tensor_filters, 283 is_callable_runner=False): 284 """Generate formatted intro for run-start UI. 285 286 Args: 287 run_call_count: (int) Run call counter. 288 fetches: Fetches of the `Session.run()` call. See doc of `Session.run()` 289 for more details. 290 feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()` 291 for more details. 292 tensor_filters: (dict) A dict from tensor-filter name to tensor-filter 293 callable. 294 is_callable_runner: (bool) whether a runner returned by 295 Session.make_callable is being run. 296 297 Returns: 298 (RichTextLines) Formatted intro message about the `Session.run()` call. 299 """ 300 301 fetch_lines = common.get_flattened_names(fetches) 302 303 if not feed_dict: 304 feed_dict_lines = [debugger_cli_common.RichLine(" (Empty)")] 305 else: 306 feed_dict_lines = [] 307 for feed_key in feed_dict: 308 feed_key_name = common.get_graph_element_name(feed_key) 309 feed_dict_line = debugger_cli_common.RichLine(" ") 310 feed_dict_line += debugger_cli_common.RichLine( 311 feed_key_name, 312 debugger_cli_common.MenuItem(None, "pf '%s'" % feed_key_name)) 313 # Surround the name string with quotes, because feed_key_name may contain 314 # spaces in some cases, e.g., SparseTensors. 315 feed_dict_lines.append(feed_dict_line) 316 feed_dict_lines = debugger_cli_common.rich_text_lines_from_rich_line_list( 317 feed_dict_lines) 318 319 out = debugger_cli_common.RichTextLines(_HORIZONTAL_BAR) 320 if is_callable_runner: 321 out.append("Running a runner returned by Session.make_callable()") 322 else: 323 out.append("Session.run() call #%d:" % run_call_count) 324 out.append("") 325 out.append("Fetch(es):") 326 out.extend(debugger_cli_common.RichTextLines( 327 [" " + line for line in fetch_lines])) 328 out.append("") 329 out.append("Feed dict:") 330 out.extend(feed_dict_lines) 331 out.append(_HORIZONTAL_BAR) 332 out.append("") 333 out.append("Select one of the following commands to proceed ---->") 334 335 out.extend( 336 _recommend_command( 337 "run", 338 "Execute the run() call with debug tensor-watching", 339 create_link=True)) 340 out.extend( 341 _recommend_command( 342 "run -n", 343 "Execute the run() call without debug tensor-watching", 344 create_link=True)) 345 out.extend( 346 _recommend_command( 347 "run -t <T>", 348 "Execute run() calls (T - 1) times without debugging, then " 349 "execute run() once more with debugging and drop back to the CLI")) 350 out.extend( 351 _recommend_command( 352 "run -f <filter_name>", 353 "Keep executing run() calls until a dumped tensor passes a given, " 354 "registered filter (conditional breakpoint mode)")) 355 356 more_lines = [" Registered filter(s):"] 357 if tensor_filters: 358 filter_names = [] 359 for filter_name in tensor_filters: 360 filter_names.append(filter_name) 361 command_menu_node = debugger_cli_common.MenuItem( 362 "", "run -f %s" % filter_name) 363 more_lines.append(RL(" * ") + RL(filter_name, command_menu_node)) 364 else: 365 more_lines.append(" (None)") 366 367 out.extend( 368 debugger_cli_common.rich_text_lines_from_rich_line_list(more_lines)) 369 370 out.extend( 371 _recommend_command( 372 "invoke_stepper", 373 "Use the node-stepper interface, which allows you to interactively " 374 "step through nodes involved in the graph run() call and " 375 "inspect/modify their values", create_link=True)) 376 377 out.append("") 378 379 out.append_rich_line(RL("For more details, see ") + 380 RL("help.", debugger_cli_common.MenuItem("", "help")) + 381 ".") 382 out.append("") 383 384 # Make main menu for the run-start intro. 385 menu = debugger_cli_common.Menu() 386 menu.append(debugger_cli_common.MenuItem("run", "run")) 387 menu.append(debugger_cli_common.MenuItem( 388 "invoke_stepper", "invoke_stepper")) 389 menu.append(debugger_cli_common.MenuItem("exit", "exit")) 390 out.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu 391 392 return out 393 394 395def get_run_short_description(run_call_count, 396 fetches, 397 feed_dict, 398 is_callable_runner=False): 399 """Get a short description of the run() call. 400 401 Args: 402 run_call_count: (int) Run call counter. 403 fetches: Fetches of the `Session.run()` call. See doc of `Session.run()` 404 for more details. 405 feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()` 406 for more details. 407 is_callable_runner: (bool) whether a runner returned by 408 Session.make_callable is being run. 409 410 Returns: 411 (str) A short description of the run() call, including information about 412 the fetche(s) and feed(s). 413 """ 414 if is_callable_runner: 415 return "runner from make_callable()" 416 417 description = "run #%d: " % run_call_count 418 419 if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)): 420 description += "1 fetch (%s); " % common.get_graph_element_name(fetches) 421 else: 422 # Could be (nested) list, tuple, dict or namedtuple. 423 num_fetches = len(common.get_flattened_names(fetches)) 424 if num_fetches > 1: 425 description += "%d fetches; " % num_fetches 426 else: 427 description += "%d fetch; " % num_fetches 428 429 if not feed_dict: 430 description += "0 feeds" 431 else: 432 if len(feed_dict) == 1: 433 for key in feed_dict: 434 description += "1 feed (%s)" % ( 435 key if isinstance(key, six.string_types) or not hasattr(key, "name") 436 else key.name) 437 else: 438 description += "%d feeds" % len(feed_dict) 439 440 return description 441 442 443def get_error_intro(tf_error): 444 """Generate formatted intro for TensorFlow run-time error. 445 446 Args: 447 tf_error: (errors.OpError) TensorFlow run-time error object. 448 449 Returns: 450 (RichTextLines) Formatted intro message about the run-time OpError, with 451 sample commands for debugging. 452 """ 453 454 if hasattr(tf_error, "op") and hasattr(tf_error.op, "name"): 455 op_name = tf_error.op.name 456 else: 457 op_name = None 458 459 intro_lines = [ 460 "--------------------------------------", 461 RL("!!! An error occurred during the run !!!", "blink"), 462 "", 463 ] 464 465 out = debugger_cli_common.rich_text_lines_from_rich_line_list(intro_lines) 466 467 if op_name is not None: 468 out.extend(debugger_cli_common.RichTextLines( 469 ["You may use the following commands to debug:"])) 470 out.extend( 471 _recommend_command("ni -a -d -t %s" % op_name, 472 "Inspect information about the failing op.", 473 create_link=True)) 474 out.extend( 475 _recommend_command("li -r %s" % op_name, 476 "List inputs to the failing op, recursively.", 477 create_link=True)) 478 479 out.extend( 480 _recommend_command( 481 "lt", 482 "List all tensors dumped during the failing run() call.", 483 create_link=True)) 484 else: 485 out.extend(debugger_cli_common.RichTextLines([ 486 "WARNING: Cannot determine the name of the op that caused the error."])) 487 488 more_lines = [ 489 "", 490 "Op name: %s" % op_name, 491 "Error type: " + str(type(tf_error)), 492 "", 493 "Details:", 494 str(tf_error), 495 "", 496 "--------------------------------------", 497 "", 498 ] 499 500 out.extend(debugger_cli_common.RichTextLines(more_lines)) 501 502 return out 503