# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Cloud TPU profiler client."""

import os
import sys

from absl import app
from absl import flags
from distutils.version import LooseVersion

from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver
from tensorflow.python.framework import errors
from tensorflow.python.framework import versions
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import profiler_client
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.tpu.profiler import version as profiler_version

FLAGS = flags.FLAGS

# Cloud TPU Cluster Resolvers
flags.DEFINE_string(
    'gcp_project', None,
    'Project name for the Cloud TPU-enabled project. If not specified, we '
    'will attempt to automatically detect the GCE project from metadata.')
flags.DEFINE_string(
    'tpu_zone',
    None,
    help='GCE zone where the Cloud TPU is located. If not specified, we '
    'will attempt to automatically detect the zone from metadata.')
flags.DEFINE_string(
    'tpu', None, 'Name of the Cloud TPU for Cluster Resolvers. You must '
    'specify either this flag or --service_addr.')

# Tool specific parameters
flags.DEFINE_string(
    'service_addr', None, 'Address of the TPU profiler service, e.g. '
    'localhost:8466. You must specify either this flag or --tpu.')
flags.DEFINE_string(
    'workers_list', None, 'The list of worker TPUs to profile, e.g. '
    '10.0.1.2:8466,10.0.1.3:8466. You can specify this flag together with '
    '--tpu or --service_addr to profile a subset of TPU nodes, or use only '
    '--tpu and leave this flag unspecified to profile all the TPUs.')
flags.DEFINE_string(
    'logdir', None, 'Path of the TensorBoard log directory, e.g. /tmp/tb_log '
    'or gs://tb_bucket.')
flags.DEFINE_integer('duration_ms', 0,
                     'Duration of tracing or monitoring in ms.')
flags.DEFINE_integer(
    'num_tracing_attempts', 3, 'Automatically retry N times when no trace '
    'event is collected.')
flags.DEFINE_boolean('include_dataset_ops', True, 'Deprecated.')
flags.DEFINE_integer(
    'host_tracer_level', 2, 'Adjust the host tracer level to control the '
    'verbosity of the TraceMe events being collected.')

# Monitoring parameters
flags.DEFINE_integer(
    'monitoring_level', 0, 'Choose a monitoring level between 1 and 2 to '
    'monitor your TPU job continuously. Level 2 is more verbose than level 1 '
    'and shows more metrics.')
flags.DEFINE_integer(
    'num_queries', 100,
    'This script will run monitoring for num_queries before it stops.')
flags.DEFINE_boolean('display_timestamp', True, 'Deprecated.')


def get_workers_list(cluster_resolver):
  """Returns a comma separated list of TPU worker host:port pairs.

  Gets cluster_spec from cluster_resolver. Uses the worker's task indices to
  obtain and return a list of host:port pairs.

  Args:
    cluster_resolver: TensorFlow TPUClusterResolver instance.

  Returns:
    A comma separated string of host:port pairs. For example:
    '10.2.0.1:8466,10.2.0.2:8466,10.2.0.3:8466,10.2.0.4:8466'

  Raises:
    UnavailableError: cluster_resolver doesn't contain a valid cluster_spec.
  """
  worker_job_name = 'worker'
  cluster_spec = cluster_resolver.cluster_spec()
  if not cluster_spec:
    raise errors.UnavailableError(
        'None', 'None',
        'Cluster spec not found; your client must run in a GCE environment.')
  task_indices = cluster_spec.task_indices(worker_job_name)
  workers_list = [
      cluster_spec.task_address(worker_job_name, i).replace(':8470', ':8466')
      for i in task_indices
  ]
  return ','.join(workers_list)


def monitoring_helper(service_addr, duration_ms, monitoring_level,
                      num_queries):
  """Helper function to print monitoring results.

  Helper function to print monitoring results for num_queries times.

  Args:
    service_addr: Address of the TPU profiler service.
    duration_ms: Duration of one monitoring sample in milliseconds.
    monitoring_level: An integer between 1 and 2. Level 2 is more verbose than
      level 1 and shows more metrics.
    num_queries: Number of monitoring samples to collect.
  """
  if monitoring_level <= 0 or monitoring_level > 2:
    sys.exit('Please choose a monitoring level between 1 and 2.')

  for query in range(0, num_queries):
    res = profiler_client.monitor(service_addr, duration_ms, monitoring_level)
    print('Cloud TPU Monitoring Results (Sample ', query, '):\n\n', res)


def run_main():
  app.run(main)


def main(unused_argv=None):
  logging.set_verbosity(logging.INFO)
  tf_version = versions.__version__
  print('TensorFlow version %s detected' % tf_version)
  print('Welcome to the Cloud TPU Profiler v%s' % profiler_version.__version__)

  if LooseVersion(tf_version) < LooseVersion('2.2.0'):
    sys.exit('You must install tensorflow >= 2.2.0 to use this plugin.')

  if not FLAGS.service_addr and not FLAGS.tpu:
    sys.exit('You must specify either --service_addr or --tpu.')

  tpu_cluster_resolver = None
  if FLAGS.service_addr:
    if FLAGS.tpu:
      logging.warn('Both --service_addr and --tpu are set. Ignoring '
                   '--tpu and using --service_addr.')
    service_addr = FLAGS.service_addr
  else:
    try:
      tpu_cluster_resolver = (
          resolver.TPUClusterResolver([FLAGS.tpu],
                                      zone=FLAGS.tpu_zone,
                                      project=FLAGS.gcp_project))
      service_addr = tpu_cluster_resolver.get_master()
    except (ValueError, TypeError):
      sys.exit('Failed to find TPU %s in zone %s project %s. You may use '
               '--tpu_zone and --gcp_project to specify the zone and project '
               'of your TPU.' % (FLAGS.tpu, FLAGS.tpu_zone, FLAGS.gcp_project))
  service_addr = service_addr.replace('grpc://', '').replace(':8470', ':8466')

  workers_list = ''
  if FLAGS.workers_list is not None:
    workers_list = FLAGS.workers_list
  elif tpu_cluster_resolver is not None:
    workers_list = get_workers_list(tpu_cluster_resolver)

  # If the profiling duration was not set by the user, or was set to a
  # non-positive value, default it to 1000 ms.
  duration_ms = FLAGS.duration_ms if FLAGS.duration_ms > 0 else 1000

  if FLAGS.monitoring_level > 0:
    print('Since monitoring level is provided, profile', service_addr, ' for ',
          duration_ms, ' ms and show metrics for ', FLAGS.num_queries,
          ' time(s).')
    monitoring_helper(service_addr, duration_ms, FLAGS.monitoring_level,
                      FLAGS.num_queries)
  else:
    if not FLAGS.logdir:
      sys.exit('You must specify either --logdir or --monitoring_level.')

    if not gfile.Exists(FLAGS.logdir):
      gfile.MakeDirs(FLAGS.logdir)

    try:
      if LooseVersion(tf_version) < LooseVersion('2.3.0'):
        profiler_client.trace(service_addr, os.path.expanduser(FLAGS.logdir),
                              duration_ms, workers_list,
                              FLAGS.num_tracing_attempts)
      else:
        options = profiler.ProfilerOptions(
            host_tracer_level=FLAGS.host_tracer_level)
        profiler_client.trace(service_addr, os.path.expanduser(FLAGS.logdir),
                              duration_ms, workers_list,
                              FLAGS.num_tracing_attempts, options)
    except errors.UnavailableError:
      sys.exit(0)


if __name__ == '__main__':
  run_main()
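
# Example invocations (illustrative sketch only). The `capture_tpu_profile`
# console-script name below is an assumption about how this module is
# packaged; only the flags themselves are defined in this file.
#
#   # Trace a named Cloud TPU for 2 seconds into a TensorBoard log directory:
#   capture_tpu_profile --tpu=my-tpu --tpu_zone=us-central1-b \
#       --gcp_project=my-project --logdir=gs://tb_bucket --duration_ms=2000
#
#   # Continuously monitor a profiler service at a known address:
#   capture_tpu_profile --service_addr=10.0.1.2:8466 \
#       --monitoring_level=2 --num_queries=50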