1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Shared functions and classes for tfdbg command-line interface.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import math 21 22import numpy as np 23import six 24 25from tensorflow.python.debug.cli import command_parser 26from tensorflow.python.debug.cli import debugger_cli_common 27from tensorflow.python.debug.cli import tensor_format 28from tensorflow.python.debug.lib import common 29from tensorflow.python.framework import ops 30from tensorflow.python.ops import variables 31from tensorflow.python.platform import gfile 32 33RL = debugger_cli_common.RichLine 34 35# Default threshold number of elements above which ellipses will be used 36# when printing the value of the tensor. 37DEFAULT_NDARRAY_DISPLAY_THRESHOLD = 2000 38 39COLOR_BLACK = "black" 40COLOR_BLUE = "blue" 41COLOR_CYAN = "cyan" 42COLOR_GRAY = "gray" 43COLOR_GREEN = "green" 44COLOR_MAGENTA = "magenta" 45COLOR_RED = "red" 46COLOR_WHITE = "white" 47COLOR_YELLOW = "yellow" 48 49TIME_UNIT_US = "us" 50TIME_UNIT_MS = "ms" 51TIME_UNIT_S = "s" 52TIME_UNITS = [TIME_UNIT_US, TIME_UNIT_MS, TIME_UNIT_S] 53 54 55def bytes_to_readable_str(num_bytes, include_b=False): 56 """Generate a human-readable string representing number of bytes. 57 58 The units B, kB, MB and GB are used. 59 60 Args: 61 num_bytes: (`int` or None) Number of bytes. 62 include_b: (`bool`) Include the letter B at the end of the unit. 63 64 Returns: 65 (`str`) A string representing the number of bytes in a human-readable way, 66 including a unit at the end. 67 """ 68 69 if num_bytes is None: 70 return str(num_bytes) 71 if num_bytes < 1024: 72 result = "%d" % num_bytes 73 elif num_bytes < 1048576: 74 result = "%.2fk" % (num_bytes / 1024.0) 75 elif num_bytes < 1073741824: 76 result = "%.2fM" % (num_bytes / 1048576.0) 77 else: 78 result = "%.2fG" % (num_bytes / 1073741824.0) 79 80 if include_b: 81 result += "B" 82 return result 83 84 85def time_to_readable_str(value_us, force_time_unit=None): 86 """Convert time value to human-readable string. 87 88 Args: 89 value_us: time value in microseconds. 90 force_time_unit: force the output to use the specified time unit. Must be 91 in TIME_UNITS. 92 93 Returns: 94 Human-readable string representation of the time value. 95 96 Raises: 97 ValueError: if force_time_unit value is not in TIME_UNITS. 98 """ 99 if not value_us: 100 return "0" 101 if force_time_unit: 102 if force_time_unit not in TIME_UNITS: 103 raise ValueError("Invalid time unit: %s" % force_time_unit) 104 order = TIME_UNITS.index(force_time_unit) 105 time_unit = force_time_unit 106 return "{:.10g}{}".format(value_us / math.pow(10.0, 3*order), time_unit) 107 else: 108 order = min(len(TIME_UNITS) - 1, int(math.log(value_us, 10) / 3)) 109 time_unit = TIME_UNITS[order] 110 return "{:.3g}{}".format(value_us / math.pow(10.0, 3*order), time_unit) 111 112 113def parse_ranges_highlight(ranges_string): 114 """Process ranges highlight string. 115 116 Args: 117 ranges_string: (str) A string representing a numerical range of a list of 118 numerical ranges. See the help info of the -r flag of the print_tensor 119 command for more details. 120 121 Returns: 122 An instance of tensor_format.HighlightOptions, if range_string is a valid 123 representation of a range or a list of ranges. 124 """ 125 126 ranges = None 127 128 def ranges_filter(x): 129 r = np.zeros(x.shape, dtype=bool) 130 for range_start, range_end in ranges: 131 r = np.logical_or(r, np.logical_and(x >= range_start, x <= range_end)) 132 133 return r 134 135 if ranges_string: 136 ranges = command_parser.parse_ranges(ranges_string) 137 return tensor_format.HighlightOptions( 138 ranges_filter, description=ranges_string) 139 else: 140 return None 141 142 143def numpy_printoptions_from_screen_info(screen_info): 144 if screen_info and "cols" in screen_info: 145 return {"linewidth": screen_info["cols"]} 146 else: 147 return {} 148 149 150def format_tensor(tensor, 151 tensor_name, 152 np_printoptions, 153 print_all=False, 154 tensor_slicing=None, 155 highlight_options=None, 156 include_numeric_summary=False, 157 write_path=None): 158 """Generate formatted str to represent a tensor or its slices. 159 160 Args: 161 tensor: (numpy ndarray) The tensor value. 162 tensor_name: (str) Name of the tensor, e.g., the tensor's debug watch key. 163 np_printoptions: (dict) Numpy tensor formatting options. 164 print_all: (bool) Whether the tensor is to be displayed in its entirety, 165 instead of printing ellipses, even if its number of elements exceeds 166 the default numpy display threshold. 167 (Note: Even if this is set to true, the screen output can still be cut 168 off by the UI frontend if it consist of more lines than the frontend 169 can handle.) 170 tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If 171 None, no slicing will be performed on the tensor. 172 highlight_options: (tensor_format.HighlightOptions) options to highlight 173 elements of the tensor. See the doc of tensor_format.format_tensor() 174 for more details. 175 include_numeric_summary: Whether a text summary of the numeric values (if 176 applicable) will be included. 177 write_path: A path to save the tensor value (after any slicing) to 178 (optional). `numpy.save()` is used to save the value. 179 180 Returns: 181 An instance of `debugger_cli_common.RichTextLines` representing the 182 (potentially sliced) tensor. 183 """ 184 185 if tensor_slicing: 186 # Validate the indexing. 187 value = command_parser.evaluate_tensor_slice(tensor, tensor_slicing) 188 sliced_name = tensor_name + tensor_slicing 189 else: 190 value = tensor 191 sliced_name = tensor_name 192 193 auxiliary_message = None 194 if write_path: 195 with gfile.Open(write_path, "wb") as output_file: 196 np.save(output_file, value) 197 line = debugger_cli_common.RichLine("Saved value to: ") 198 line += debugger_cli_common.RichLine(write_path, font_attr="bold") 199 line += " (%sB)" % bytes_to_readable_str(gfile.Stat(write_path).length) 200 auxiliary_message = debugger_cli_common.rich_text_lines_from_rich_line_list( 201 [line, debugger_cli_common.RichLine("")]) 202 203 if print_all: 204 np_printoptions["threshold"] = value.size 205 else: 206 np_printoptions["threshold"] = DEFAULT_NDARRAY_DISPLAY_THRESHOLD 207 208 return tensor_format.format_tensor( 209 value, 210 sliced_name, 211 include_metadata=True, 212 include_numeric_summary=include_numeric_summary, 213 auxiliary_message=auxiliary_message, 214 np_printoptions=np_printoptions, 215 highlight_options=highlight_options) 216 217 218def error(msg): 219 """Generate a RichTextLines output for error. 220 221 Args: 222 msg: (str) The error message. 223 224 Returns: 225 (debugger_cli_common.RichTextLines) A representation of the error message 226 for screen output. 227 """ 228 229 return debugger_cli_common.rich_text_lines_from_rich_line_list([ 230 RL("ERROR: " + msg, COLOR_RED)]) 231 232 233def _recommend_command(command, description, indent=2, create_link=False): 234 """Generate a RichTextLines object that describes a recommended command. 235 236 Args: 237 command: (str) The command to recommend. 238 description: (str) A description of what the command does. 239 indent: (int) How many spaces to indent in the beginning. 240 create_link: (bool) Whether a command link is to be applied to the command 241 string. 242 243 Returns: 244 (RichTextLines) Formatted text (with font attributes) for recommending the 245 command. 246 """ 247 248 indent_str = " " * indent 249 250 if create_link: 251 font_attr = [debugger_cli_common.MenuItem("", command), "bold"] 252 else: 253 font_attr = "bold" 254 255 lines = [RL(indent_str) + RL(command, font_attr) + ":", 256 indent_str + " " + description] 257 258 return debugger_cli_common.rich_text_lines_from_rich_line_list(lines) 259 260 261def get_tfdbg_logo(): 262 """Make an ASCII representation of the tfdbg logo.""" 263 264 lines = [ 265 "", 266 "TTTTTT FFFF DDD BBBB GGG ", 267 " TT F D D B B G ", 268 " TT FFF D D BBBB G GG", 269 " TT F D D B B G G", 270 " TT F DDD BBBB GGG ", 271 "", 272 ] 273 return debugger_cli_common.RichTextLines(lines) 274 275 276_HORIZONTAL_BAR = "======================================" 277 278 279def get_run_start_intro(run_call_count, 280 fetches, 281 feed_dict, 282 tensor_filters, 283 is_callable_runner=False): 284 """Generate formatted intro for run-start UI. 285 286 Args: 287 run_call_count: (int) Run call counter. 288 fetches: Fetches of the `Session.run()` call. See doc of `Session.run()` 289 for more details. 290 feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()` 291 for more details. 292 tensor_filters: (dict) A dict from tensor-filter name to tensor-filter 293 callable. 294 is_callable_runner: (bool) whether a runner returned by 295 Session.make_callable is being run. 296 297 Returns: 298 (RichTextLines) Formatted intro message about the `Session.run()` call. 299 """ 300 301 fetch_lines = common.get_flattened_names(fetches) 302 303 if not feed_dict: 304 feed_dict_lines = [debugger_cli_common.RichLine(" (Empty)")] 305 else: 306 feed_dict_lines = [] 307 for feed_key in feed_dict: 308 feed_key_name = common.get_graph_element_name(feed_key) 309 feed_dict_line = debugger_cli_common.RichLine(" ") 310 feed_dict_line += debugger_cli_common.RichLine( 311 feed_key_name, 312 debugger_cli_common.MenuItem(None, "pf '%s'" % feed_key_name)) 313 # Surround the name string with quotes, because feed_key_name may contain 314 # spaces in some cases, e.g., SparseTensors. 315 feed_dict_lines.append(feed_dict_line) 316 feed_dict_lines = debugger_cli_common.rich_text_lines_from_rich_line_list( 317 feed_dict_lines) 318 319 out = debugger_cli_common.RichTextLines(_HORIZONTAL_BAR) 320 if is_callable_runner: 321 out.append("Running a runner returned by Session.make_callable()") 322 else: 323 out.append("Session.run() call #%d:" % run_call_count) 324 out.append("") 325 out.append("Fetch(es):") 326 out.extend(debugger_cli_common.RichTextLines( 327 [" " + line for line in fetch_lines])) 328 out.append("") 329 out.append("Feed dict:") 330 out.extend(feed_dict_lines) 331 out.append(_HORIZONTAL_BAR) 332 out.append("") 333 out.append("Select one of the following commands to proceed ---->") 334 335 out.extend( 336 _recommend_command( 337 "run", 338 "Execute the run() call with debug tensor-watching", 339 create_link=True)) 340 out.extend( 341 _recommend_command( 342 "run -n", 343 "Execute the run() call without debug tensor-watching", 344 create_link=True)) 345 out.extend( 346 _recommend_command( 347 "run -t <T>", 348 "Execute run() calls (T - 1) times without debugging, then " 349 "execute run() once more with debugging and drop back to the CLI")) 350 out.extend( 351 _recommend_command( 352 "run -f <filter_name>", 353 "Keep executing run() calls until a dumped tensor passes a given, " 354 "registered filter (conditional breakpoint mode)")) 355 356 more_lines = [" Registered filter(s):"] 357 if tensor_filters: 358 filter_names = [] 359 for filter_name in tensor_filters: 360 filter_names.append(filter_name) 361 command_menu_node = debugger_cli_common.MenuItem( 362 "", "run -f %s" % filter_name) 363 more_lines.append(RL(" * ") + RL(filter_name, command_menu_node)) 364 else: 365 more_lines.append(" (None)") 366 367 out.extend( 368 debugger_cli_common.rich_text_lines_from_rich_line_list(more_lines)) 369 370 out.append("") 371 372 out.append_rich_line(RL("For more details, see ") + 373 RL("help.", debugger_cli_common.MenuItem("", "help")) + 374 ".") 375 out.append("") 376 377 # Make main menu for the run-start intro. 378 menu = debugger_cli_common.Menu() 379 menu.append(debugger_cli_common.MenuItem("run", "run")) 380 menu.append(debugger_cli_common.MenuItem("exit", "exit")) 381 out.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu 382 383 return out 384 385 386def get_run_short_description(run_call_count, 387 fetches, 388 feed_dict, 389 is_callable_runner=False): 390 """Get a short description of the run() call. 391 392 Args: 393 run_call_count: (int) Run call counter. 394 fetches: Fetches of the `Session.run()` call. See doc of `Session.run()` 395 for more details. 396 feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()` 397 for more details. 398 is_callable_runner: (bool) whether a runner returned by 399 Session.make_callable is being run. 400 401 Returns: 402 (str) A short description of the run() call, including information about 403 the fetche(s) and feed(s). 404 """ 405 if is_callable_runner: 406 return "runner from make_callable()" 407 408 description = "run #%d: " % run_call_count 409 410 if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)): 411 description += "1 fetch (%s); " % common.get_graph_element_name(fetches) 412 else: 413 # Could be (nested) list, tuple, dict or namedtuple. 414 num_fetches = len(common.get_flattened_names(fetches)) 415 if num_fetches > 1: 416 description += "%d fetches; " % num_fetches 417 else: 418 description += "%d fetch; " % num_fetches 419 420 if not feed_dict: 421 description += "0 feeds" 422 else: 423 if len(feed_dict) == 1: 424 for key in feed_dict: 425 description += "1 feed (%s)" % ( 426 key if isinstance(key, six.string_types) or not hasattr(key, "name") 427 else key.name) 428 else: 429 description += "%d feeds" % len(feed_dict) 430 431 return description 432 433 434def get_error_intro(tf_error): 435 """Generate formatted intro for TensorFlow run-time error. 436 437 Args: 438 tf_error: (errors.OpError) TensorFlow run-time error object. 439 440 Returns: 441 (RichTextLines) Formatted intro message about the run-time OpError, with 442 sample commands for debugging. 443 """ 444 445 if hasattr(tf_error, "op") and hasattr(tf_error.op, "name"): 446 op_name = tf_error.op.name 447 else: 448 op_name = None 449 450 intro_lines = [ 451 "--------------------------------------", 452 RL("!!! An error occurred during the run !!!", "blink"), 453 "", 454 ] 455 456 out = debugger_cli_common.rich_text_lines_from_rich_line_list(intro_lines) 457 458 if op_name is not None: 459 out.extend(debugger_cli_common.RichTextLines( 460 ["You may use the following commands to debug:"])) 461 out.extend( 462 _recommend_command("ni -a -d -t %s" % op_name, 463 "Inspect information about the failing op.", 464 create_link=True)) 465 out.extend( 466 _recommend_command("li -r %s" % op_name, 467 "List inputs to the failing op, recursively.", 468 create_link=True)) 469 470 out.extend( 471 _recommend_command( 472 "lt", 473 "List all tensors dumped during the failing run() call.", 474 create_link=True)) 475 else: 476 out.extend(debugger_cli_common.RichTextLines([ 477 "WARNING: Cannot determine the name of the op that caused the error."])) 478 479 more_lines = [ 480 "", 481 "Op name: %s" % op_name, 482 "Error type: " + str(type(tf_error)), 483 "", 484 "Details:", 485 str(tf_error), 486 "", 487 "--------------------------------------", 488 "", 489 ] 490 491 out.extend(debugger_cli_common.RichTextLines(more_lines)) 492 493 return out 494