• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""tfdbg example: debugging tf.keras models training on tf.data.Dataset."""
16
17import argparse
18import sys
19import tempfile
20
21import numpy as np
22import tensorflow
23
24from tensorflow.python import debug as tf_debug
25
26tf = tensorflow.compat.v1
27
28
def main(_):
  """Trains a one-layer Keras model on a dummy dataset, optionally under tfdbg.

  Behavior is controlled by the module-level `FLAGS` object:
    * `FLAGS.debug`: wrap the session in the tfdbg command-line interface.
      This is checked first, so it wins if `--tensorboard_debug_address` is
      given as well.
    * `FLAGS.ui_type`: UI type handed to the CLI wrapper.
    * `FLAGS.use_random_config_path`: if true, the CLI wrapper is given a
      freshly created temporary file as its config file path; otherwise
      `None` is passed.
    * `FLAGS.tensorboard_debug_address`: if set (and `FLAGS.debug` is not),
      wrap the session so it talks to the TensorBoard Debugger Plugin.
    * `FLAGS.epochs`: number of epochs passed to `model.fit`.

  Args:
    _: Unused. `tf.app.run` calls `main` with the remaining argv.
  """
  # Create a dummy dataset.
  num_examples = 8
  steps_per_epoch = 2
  input_dims = 3
  output_dims = 1
  batch_size = num_examples // steps_per_epoch
  xs = np.zeros([num_examples, input_dims])
  ys = np.zeros([num_examples, output_dims])
  # Repeat indefinitely. `model.fit` below is bounded by `steps_per_epoch`, so
  # an endless input is safe, and it cannot run out of batches regardless of
  # how large --epochs is (a finite repeat count caps the usable epochs).
  dataset = tf.data.Dataset.from_tensor_slices(
      (xs, ys)).repeat().batch(batch_size)

  sess = tf.Session()
  if FLAGS.debug:
    # Use the command-line interface (CLI) of tfdbg.
    config_file_path = None
    if FLAGS.use_random_config_path:
      # Only the path is needed. Creating the file inside a `with` block
      # closes the handle immediately instead of leaking an open descriptor;
      # `delete=False` keeps the file on disk for the debugger to use.
      with tempfile.NamedTemporaryFile(
          suffix=".tfdbg_config", delete=False) as config_file:
        config_file_path = config_file.name
    sess = tf_debug.LocalCLIDebugWrapperSession(
        sess, ui_type=FLAGS.ui_type, config_file_path=config_file_path)
  elif FLAGS.tensorboard_debug_address:
    # Use the TensorBoard Debugger Plugin (GUI of tfdbg).
    sess = tf_debug.TensorBoardDebugWrapperSession(
        sess, FLAGS.tensorboard_debug_address)
  # Make Keras run on the (possibly debugger-wrapped) session.
  tf.keras.backend.set_session(sess)

  # Create a dummy model.
  model = tf.keras.Sequential(
      [tf.keras.layers.Dense(1, input_shape=[input_dims])])
  model.compile(loss="mse", optimizer="sgd")

  # Train the model using the dummy dataset created above.
  model.fit(dataset, epochs=FLAGS.epochs, steps_per_epoch=steps_per_epoch)
62
63
if __name__ == "__main__":
  # Keyword arguments common to the optional boolean flags: a bare `--flag`
  # yields True (const), `--flag=<text>` goes through the "bool" type
  # registered below, and an omitted flag yields False (default).
  bool_flag_kwargs = dict(type="bool", nargs="?", const=True, default=False)

  arg_parser = argparse.ArgumentParser()
  # Custom "bool" type: only the text "true" (any letter case) maps to True.
  arg_parser.register("type", "bool", lambda text: text.lower() == "true")

  arg_parser.add_argument(
      "--debug",
      help="Use debugger to track down bad values during training. "
      "Mutually exclusive with the --tensorboard_debug_address flag.",
      **bool_flag_kwargs)
  arg_parser.add_argument(
      "--ui_type",
      default="curses",
      type=str,
      help="Command-line user interface type (curses | readline).")
  arg_parser.add_argument(
      "--use_random_config_path",
      help="""If set, set config file path to a random file in the temporary
      directory.""",
      **bool_flag_kwargs)
  arg_parser.add_argument(
      "--tensorboard_debug_address",
      default=None,
      type=str,
      help="Connect to the TensorBoard Debugger Plugin backend specified by "
      "the gRPC address (e.g., localhost:1234). Mutually exclusive with the "
      "--debug flag.")
  arg_parser.add_argument(
      "--epochs",
      default=2,
      type=int,
      help="Number of epochs to train the model for.")

  # Flags this script does not recognize are forwarded to `tf.app.run`
  # together with the program name; `main` reads the parsed values from the
  # module-level FLAGS.
  FLAGS, leftover_args = arg_parser.parse_known_args()
  forwarded_argv = [sys.argv[0]] + leftover_args
  with tf.Graph().as_default():
    tf.app.run(main=main, argv=forwarded_argv)
103