1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Demo of the tfdbg curses UI: A TF network computing Fibonacci sequence.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import argparse 21import sys 22 23import numpy as np 24from six.moves import xrange # pylint: disable=redefined-builtin 25import tensorflow 26 27from tensorflow.python import debug as tf_debug 28 29tf = tensorflow.compat.v1 30 31FLAGS = None 32 33 34def main(_): 35 sess = tf.Session() 36 37 # Construct the TensorFlow network. 38 n0 = tf.Variable( 39 np.ones([FLAGS.tensor_size] * 2), dtype=tf.int32, name="node_00") 40 n1 = tf.Variable( 41 np.ones([FLAGS.tensor_size] * 2), dtype=tf.int32, name="node_01") 42 43 for i in xrange(2, FLAGS.length): 44 n0, n1 = n1, tf.add(n0, n1, name="node_%.2d" % i) 45 46 sess.run(tf.global_variables_initializer()) 47 48 # Wrap the TensorFlow Session object for debugging. 49 if FLAGS.debug and FLAGS.tensorboard_debug_address: 50 raise ValueError( 51 "The --debug and --tensorboard_debug_address flags are mutually " 52 "exclusive.") 53 if FLAGS.debug: 54 sess = tf_debug.LocalCLIDebugWrapperSession(sess) 55 56 def has_negative(_, tensor): 57 return np.any(tensor < 0) 58 59 sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan) 60 sess.add_tensor_filter("has_negative", has_negative) 61 elif FLAGS.tensorboard_debug_address: 62 sess = tf_debug.TensorBoardDebugWrapperSession( 63 sess, FLAGS.tensorboard_debug_address) 64 65 print("Fibonacci number at position %d:\n%s" % (FLAGS.length, sess.run(n1))) 66 67 68if __name__ == "__main__": 69 parser = argparse.ArgumentParser() 70 parser.register("type", "bool", lambda v: v.lower() == "true") 71 parser.add_argument( 72 "--tensor_size", 73 type=int, 74 default=1, 75 help="""\ 76 Size of tensor. E.g., if the value is 30, the tensors will have shape 77 [30, 30].\ 78 """) 79 parser.add_argument( 80 "--length", 81 type=int, 82 default=20, 83 help="Length of the fibonacci sequence to compute.") 84 parser.add_argument( 85 "--ui_type", 86 type=str, 87 default="curses", 88 help="Command-line user interface type (curses | readline)") 89 parser.add_argument( 90 "--debug", 91 dest="debug", 92 action="store_true", 93 help="Use TensorFlow Debugger (tfdbg). Mutually exclusive with the " 94 "--tensorboard_debug_address flag.") 95 parser.add_argument( 96 "--tensorboard_debug_address", 97 type=str, 98 default=None, 99 help="Connect to the TensorBoard Debugger Plugin backend specified by " 100 "the gRPC address (e.g., localhost:1234). Mutually exclusive with the " 101 "--debug flag.") 102 103 FLAGS, unparsed = parser.parse_known_args() 104 with tf.Graph().as_default(): 105 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 106