• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Cloud TPU profiler client."""
16
17import os
18import sys
19
20from absl import app
21from absl import flags
22from distutils.version import LooseVersion
23
24from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
25from tensorflow.python.profiler import profiler_client
26from tensorflow.python.profiler import profiler_v2 as profiler
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import versions
29from tensorflow.python.platform import gfile
30from tensorflow.python.platform import tf_logging as logging
31from tensorflow.python.tpu.profiler import version as profiler_version
32
FLAGS = flags.FLAGS

# Cloud TPU Cluster Resolvers
flags.DEFINE_string(
    'gcp_project', None,
    'Project name for the Cloud TPU-enabled project. If not specified, we '
    'will attempt to automatically detect the GCE project from metadata.')
flags.DEFINE_string(
    'tpu_zone', None,
    'GCE zone where the Cloud TPU is located. If not specified, we '
    'will attempt to automatically detect the GCE zone from metadata.')
flags.DEFINE_string(
    'tpu', None, 'Name of the Cloud TPU for Cluster Resolvers. You must '
    'specify either this flag or --service_addr.')

# Tool specific parameters
flags.DEFINE_string(
    'service_addr', None, 'Address of TPU profiler service e.g. '
    'localhost:8466, you must specify either this flag or --tpu.')
flags.DEFINE_string(
    'workers_list', None, 'The comma-separated list of worker TPUs that we '
    'are about to profile e.g. 10.0.1.2:8466,10.0.1.3:8466. You can specify '
    'this flag with --tpu or --service_addr to profile a subset of tpu nodes. '
    'You can also use only --tpu and leave this flag unspecified to profile '
    'all the tpus.')
flags.DEFINE_string(
    'logdir', None, 'Path of TensorBoard log directory e.g. /tmp/tb_log, '
    'gs://tb_bucket')
flags.DEFINE_integer(
    'duration_ms', 0, 'Duration of tracing or monitoring in ms. Non-positive '
    'values fall back to 1000 ms.')
flags.DEFINE_integer(
    'num_tracing_attempts', 3, 'Automatically retry N times when no trace '
    'event is collected.')
flags.DEFINE_boolean('include_dataset_ops', True, 'Deprecated.')
flags.DEFINE_integer(
    'host_tracer_level', 2, 'Adjust host tracer level to control the verbosity '
    'of the TraceMe events being collected (TensorFlow >= 2.3 only).')

# Monitoring parameters
flags.DEFINE_integer(
    'monitoring_level', 0, 'Choose a monitoring level between '
    '1 and 2 to monitor your TPU job continuously. Level 2 is more verbose than'
    ' level 1 and shows more metrics.')
flags.DEFINE_integer(
    'num_queries', 100,
    'This script will run monitoring for num_queries before it stops.')
flags.DEFINE_boolean('display_timestamp', True, 'Deprecated.')
80
81
def get_workers_list(cluster_resolver):
  """Builds the comma-separated host:port list of all TPU workers.

  Reads the cluster spec from `cluster_resolver`, looks up every task of the
  'worker' job and rewrites each address's ':8470' port to ':8466'.

  Args:
    cluster_resolver: TensorFlow TPUClusterResolver instance.

  Returns:
    A comma-joined string of host:port pairs, e.g.
    '10.2.0.1:8466,10.2.0.2:8466,10.2.0.3:8466,10.2.0.4:8466'.

  Raises:
    UnavailableError: the resolver returned an empty / missing cluster_spec.
  """
  job = 'worker'
  spec = cluster_resolver.cluster_spec()
  if not spec:
    raise errors.UnavailableError(
        'None', 'None',
        'Cluster spec not found, your client must run in GCE environment.')
  # Swap the default TPU port for the profiler service port on every worker.
  return ','.join(
      spec.task_address(job, idx).replace(':8470', ':8466')
      for idx in spec.task_indices(job))
110
111
def monitoring_helper(service_addr, duration_ms, monitoring_level, num_queries):
  """Queries the profiler service repeatedly and prints each monitoring sample.

  Exits the process with an error message if `monitoring_level` is outside
  the range (0, 2].

  Args:
    service_addr: Address of the TPU profiler service.
    duration_ms: Duration of one monitoring sample in milliseconds.
    monitoring_level: An integer between 1 and 2. Level 2 is more verbose than
      level 1 and shows more metrics.
    num_queries: Number of monitoring samples to collect and print.
  """
  if not 0 < monitoring_level <= 2:
    sys.exit('Please choose a monitoring level between 1 and 2.')

  for sample_idx in range(num_queries):
    result = profiler_client.monitor(service_addr, duration_ms,
                                     monitoring_level)
    print('Cloud TPU Monitoring Results (Sample ', sample_idx, '):\n\n',
          result)
130
131
def run_main():
  """Parses command-line flags via absl and dispatches to `main`."""
  app.run(main=main)
134
135
def main(unused_argv=None):
  """Profiles or monitors a Cloud TPU according to the command-line flags.

  Resolves the profiler service address from --service_addr, or by looking up
  --tpu with a TPUClusterResolver. If --monitoring_level is positive, prints
  --num_queries monitoring samples. Otherwise it captures a trace into
  --logdir.

  Args:
    unused_argv: Leftover positional command-line arguments; ignored.
  """
  logging.set_verbosity(logging.INFO)
  tf_version = versions.__version__
  print('TensorFlow version %s detected' % tf_version)
  print('Welcome to the Cloud TPU Profiler v%s' % profiler_version.__version__)

  if LooseVersion(tf_version) < LooseVersion('2.2.0'):
    sys.exit('You must install tensorflow >= 2.2.0 to use this plugin.')

  if not FLAGS.service_addr and not FLAGS.tpu:
    sys.exit('You must specify either --service_addr or --tpu.')

  tpu_cluster_resolver = None
  if FLAGS.service_addr:
    if FLAGS.tpu:
      logging.warning('Both --service_addr and --tpu are set. Ignoring '
                      '--tpu and using --service_addr.')
    service_addr = FLAGS.service_addr
  else:
    try:
      tpu_cluster_resolver = resolver.TPUClusterResolver(
          [FLAGS.tpu], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
      service_addr = tpu_cluster_resolver.get_master()
    except (ValueError, TypeError):
      sys.exit('Failed to find TPU %s in zone %s project %s. You may use '
               '--tpu_zone and --gcp_project to specify the zone and project of'
               ' your TPU.' % (FLAGS.tpu, FLAGS.tpu_zone, FLAGS.gcp_project))
  # Drop any 'grpc://' scheme and map port 8470 to the profiler port 8466.
  service_addr = service_addr.replace('grpc://', '').replace(':8470', ':8466')

  # An explicit --workers_list wins; otherwise profile every worker the
  # resolver knows about (only available when --tpu was used).
  workers_list = ''
  if FLAGS.workers_list is not None:
    workers_list = FLAGS.workers_list
  elif tpu_cluster_resolver is not None:
    workers_list = get_workers_list(tpu_cluster_resolver)

  # If profiling duration was not set by user or set to a non-positive value,
  # we set it to a default value of 1000ms.
  duration_ms = FLAGS.duration_ms if FLAGS.duration_ms > 0 else 1000

  if FLAGS.monitoring_level > 0:
    # Report the effective duration, not the raw flag (which may be <= 0).
    print('Since monitoring level is provided, profile', service_addr, 'for',
          duration_ms, 'ms and show metrics for', FLAGS.num_queries,
          'time(s).')
    monitoring_helper(service_addr, duration_ms, FLAGS.monitoring_level,
                      FLAGS.num_queries)
  else:
    if not FLAGS.logdir:
      sys.exit('You must specify either --logdir or --monitoring_level.')

    # Expand '~' once so the directory we create is the one the trace is
    # written to (previously MakeDirs used the unexpanded path).
    logdir = os.path.expanduser(FLAGS.logdir)
    if not gfile.Exists(logdir):
      gfile.MakeDirs(logdir)

    try:
      if LooseVersion(tf_version) < LooseVersion('2.3.0'):
        # ProfilerOptions is not passed to trace() before TF 2.3.
        profiler_client.trace(service_addr, logdir, duration_ms, workers_list,
                              FLAGS.num_tracing_attempts)
      else:
        options = profiler.ProfilerOptions(
            host_tracer_level=FLAGS.host_tracer_level)
        profiler_client.trace(service_addr, logdir, duration_ms, workers_list,
                              FLAGS.num_tracing_attempts, options)
    except errors.UnavailableError as e:
      # Surface the reason instead of exiting silently. The exit status stays 0
      # to preserve existing behavior. NOTE(review): a non-zero status may suit
      # scripted callers better.
      logging.error('Failed to capture a profile: %s', e)
      sys.exit(0)
203
204
if __name__ == '__main__':
  # Script entry point: let absl parse flags, then hand control to main().
  app.run(main)
207