1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ======================================================================== 15"""Tensor Tracer report generation utilities.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import hashlib 23import os 24 25 26from tensorflow.python.platform import gfile 27from tensorflow.python.platform import tf_logging as logging 28from tensorflow.python.tpu import tensor_tracer_pb2 29 30_TRACER_LOG_PREFIX = ' [>>>TT>>>]' 31_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:' 32_MARKER_SECTION_END = '!!!!!!! section-end:' 33 34_SECTION_NAME_CONFIG = 'configuration' 35_SECTION_NAME_REASON = 'reason' 36_SECTION_NAME_OP_LIST = 'op-list' 37_SECTION_NAME_TENSOR_LIST = 'tensor-list' 38_SECTION_NAME_CACHE_INDEX_MAP = 'cache-index-map' 39_SECTION_NAME_GRAPH = 'graph' 40_SECTION_NAME_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint' 41 42_FIELD_NAME_VERSION = 'version:' 43_FIELD_NAME_DEVICE = 'device:' 44_FIELD_NAME_TRACE_MODE = 'trace-mode:' 45_FIELD_NAME_SUBMODE = 'submode:' 46_FIELD_NAME_NUM_REPLICAS = 'num-replicas:' 47_FIELD_NAME_NUM_REPLICAS_PER_HOST = 'num-replicas-per-host:' 48_FIELD_NAME_NUM_HOSTS = 'num-hosts:' 49_FIELD_NAME_NUM_OPS = 'number-of-ops:' 50_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:' 51_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:' 52_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:' 53 54_CURRENT_VERSION = 'use-outside-compilation' 55_TT_REPORT_PROTO = 'tensor_tracer_report.report_pb' 56 57 58def report_proto_path(trace_dir): 59 """Returns the path where report proto should be written. 60 61 Args: 62 trace_dir: String denoting the trace directory. 63 64 Returns: 65 A string denoting the path to the report proto. 66 """ 67 return os.path.join(trace_dir, _TT_REPORT_PROTO) 68 69 70def topological_sort(g): 71 """Performs topological sort on the given graph. 72 73 Args: 74 g: the graph. 75 76 Returns: 77 A pair where the first element indicates if the topological 78 sort succeeded (True if there is no cycle found; False if a 79 cycle is found) and the second element is either the sorted 80 list of nodes or the cycle of nodes found. 81 """ 82 def _is_loop_edge(op): 83 """Returns true if the op is the end of a while-loop creating a cycle.""" 84 return op.type in ['NextIteration'] 85 86 def _in_op_degree(op): 87 """Returns the number of incoming edges to the given op. 88 89 The edge calculation skips the edges that come from 'NextIteration' ops. 90 NextIteration creates a cycle in the graph. We break cycles by treating 91 this op as 'sink' and ignoring all outgoing edges from it. 92 Args: 93 op: Tf.Operation 94 Returns: 95 the number of incoming edges. 96 """ 97 count = 0 98 for op in op.control_inputs + [in_tensor.op for in_tensor in op.inputs]: 99 if not _is_loop_edge(op): 100 count += 1 101 return count 102 103 sorted_ops = [] 104 op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()} 105 106 frontier = [op for (op, degree) in op_in_degree.items() if degree == 0] 107 frontier.sort(key=lambda op: op.name) 108 while frontier: 109 op = frontier.pop() 110 # Remove the op from graph, and remove its outgoing edges. 111 sorted_ops.append(op) 112 if _is_loop_edge(op): 113 continue 114 # pylint: disable=protected-access 115 consumers = list(op._control_outputs) 116 # pylint: enable=protected-access 117 for out_tensor in op.outputs: 118 consumers += [consumer_op for consumer_op in out_tensor.consumers()] 119 consumers.sort(key=lambda op: op.name) 120 for consumer in consumers: 121 # For each deleted edge shift the bucket of the vertex. 122 op_in_degree[consumer] -= 1 123 if op_in_degree[consumer] == 0: 124 frontier.append(consumer) 125 if op_in_degree[consumer] < 0: 126 raise ValueError('consumer:%s degree mismatch'%consumer.name) 127 128 left_ops = set(op for (op, degree) in op_in_degree.items() if degree > 0) 129 if left_ops: 130 return (True, left_ops) 131 else: 132 assert len(g.get_operations()) == len(sorted_ops) 133 return (False, sorted_ops) 134 135 136class TensorTracerConfig(object): 137 """Tensor Tracer config object.""" 138 139 def __init__(self): 140 self.version = _CURRENT_VERSION 141 self.device_type = None 142 self.num_replicas = None 143 self.num_replicas_per_host = None 144 self.num_hosts = None 145 146 147class TensorTraceOrder(object): 148 """Class that is responsible from storing the trace-id of the tensors.""" 149 150 def __init__(self, graph_order, traced_tensors): 151 self.graph_order = graph_order 152 self.traced_tensors = traced_tensors 153 self._create_tensor_maps() 154 155 def _create_tensor_maps(self): 156 """Creates tensor to cache id maps.""" 157 self.tensorname_to_cache_idx = {} 158 self.cache_idx_to_tensor_idx = [] 159 for out_tensor in self.traced_tensors: 160 tensor_name = out_tensor.name 161 if tensor_name in self.tensorname_to_cache_idx: 162 raise ValueError( 163 'Tensor name %s should not be already in ' 164 'tensorname_to_cache_idx'%tensor_name) 165 if tensor_name not in self.graph_order.tensor_to_idx: 166 raise ValueError( 167 'Tensor name %s is not in the tensor_to_idx'%tensor_name) 168 tensor_idx = self.graph_order.tensor_to_idx[tensor_name] 169 cache_idx = len(self.tensorname_to_cache_idx) 170 self.tensorname_to_cache_idx[tensor_name] = cache_idx 171 self.cache_idx_to_tensor_idx.append(tensor_idx) 172 if len(self.tensorname_to_cache_idx) != len( 173 self.cache_idx_to_tensor_idx): 174 raise RuntimeError('len(self.tensorname_to_cache_idx) != ' 175 'len(self.cache_idx_to_tensor_idx') 176 177 178def sort_tensors_and_ops(graph): 179 """Returns a wrapper that has consistent tensor and op orders.""" 180 graph_wrapper = collections.namedtuple('GraphWrapper', 181 ['graph', 'operations', 'op_to_idx', 182 'tensors', 'tensor_to_idx', 183 'contains_cycle', 184 'topological_order_or_cycle']) 185 contains_cycle, topological_order_or_cycle = topological_sort(graph) 186 if not contains_cycle: 187 operations = topological_order_or_cycle 188 else: 189 operations = graph.get_operations() 190 op_to_idx = {op.name: index for index, op 191 in enumerate(operations)} 192 tensors = [] 193 for op in operations: 194 tensors.extend(op.outputs) 195 tensor_to_idx = {tensor.name: index for index, tensor in 196 enumerate(tensors)} 197 return graph_wrapper(graph=graph, operations=operations, op_to_idx=op_to_idx, 198 tensors=tensors, tensor_to_idx=tensor_to_idx, 199 contains_cycle=contains_cycle, 200 topological_order_or_cycle=topological_order_or_cycle) 201 202 203class OpenReportFile(object): 204 """Context manager for writing report file.""" 205 206 def __init__(self, tt_parameters): 207 if not tt_parameters.report_file_path: 208 self._report_file = None 209 return 210 try: 211 self._report_file = gfile.Open(tt_parameters.report_file_path, 'w') 212 except IOError as e: 213 raise e 214 215 def __enter__(self): 216 return self._report_file 217 218 def __exit__(self, unused_type, unused_value, unused_traceback): 219 if self._report_file: 220 self._report_file.close() 221 222 223def proto_fingerprint(message_proto): 224 serialized_message = message_proto.SerializeToString() 225 hasher = hashlib.sha256(serialized_message) 226 return hasher.hexdigest() 227 228 229class TTReportHandle(object): 230 """Utility class responsible from creating a tensor tracer report.""" 231 232 def __init__(self): 233 self.instrument_records = {} 234 self._report_file = None 235 236 def instrument(self, name, explanation): 237 self.instrument_records[name] = explanation 238 239 def instrument_op(self, op, explanation): 240 self.instrument(op.name, explanation) 241 242 def instrument_tensor(self, tensor, explanation): 243 self.instrument(tensor.name, explanation) 244 245 def create_report_proto(self, tt_config, tt_parameters, tensor_trace_order, 246 tensor_trace_points, collected_signature_types): 247 """Creates and returns a proto that stores tensor tracer configuration. 248 249 Args: 250 tt_config: TensorTracerConfig object holding information about the run 251 environment (device, # cores, # hosts), and tensor tracer version 252 information. 253 tt_parameters: TTParameters objects storing the user provided parameters 254 for tensor tracer. 255 tensor_trace_order: TensorTraceOrder object storing a topological order of 256 the graph. 257 tensor_trace_points: Progromatically added trace_points/checkpoints. 258 collected_signature_types: The signature types collected, e,g, norm, 259 max, min, mean... 260 Returns: 261 TensorTracerReport proto. 262 """ 263 report = tensor_tracer_pb2.TensorTracerReport() 264 report.config.version = tt_config.version 265 report.config.device = tt_config.device_type 266 report.config.num_cores = tt_config.num_replicas 267 report.config.num_hosts = tt_config.num_hosts 268 report.config.num_cores_per_host = tt_config.num_replicas_per_host 269 report.config.submode = tt_parameters.submode 270 report.config.trace_mode = tt_parameters.trace_mode 271 272 for signature_name, _ in sorted(collected_signature_types.items(), 273 key=lambda x: x[1]): 274 report.config.signatures.append(signature_name) 275 276 for tensor in tensor_trace_order.graph_order.tensors: 277 tensor_def = tensor_tracer_pb2.TensorTracerReport.TracedTensorDef() 278 tensor_def.name = tensor.name 279 if tensor.name in tensor_trace_order.tensorname_to_cache_idx: 280 tensor_def.is_traced = True 281 tensor_def.cache_index = ( 282 tensor_trace_order.tensorname_to_cache_idx[tensor.name]) 283 else: 284 # To prevent small changes affecting the fingerprint calculation, avoid 285 # writing the untraced tensors to metadata. Fingerprints will be 286 # different only when the list of the traced tensors are different. 287 if tt_parameters.use_fingerprint_subdir: 288 continue 289 tensor_def.is_traced = False 290 291 if tensor.name in tensor_trace_points: 292 tensor_def.trace_point_name = tensor_trace_points[tensor.name] 293 if tensor.name in self.instrument_records: 294 tensor_def.explanation = self.instrument_records[tensor.name] 295 elif tensor.op.name in self.instrument_records: 296 tensor_def.explanation = self.instrument_records[tensor.op.name] 297 report.tensordef[tensor.name].CopyFrom(tensor_def) 298 report.fingerprint = proto_fingerprint(report) 299 logging.info('TensorTracerProto fingerprint is %s.', 300 report.fingerprint) 301 tf_graph = tensor_trace_order.graph_order.graph 302 report.graphdef.CopyFrom(tf_graph.as_graph_def()) 303 return report 304 305 def write_report_proto(self, report_proto, tt_parameters): 306 """Writes the given report proto under trace_dir.""" 307 gfile.MakeDirs(tt_parameters.trace_dir) 308 report_path = report_proto_path(tt_parameters.trace_dir) 309 with gfile.GFile(report_path, 'wb') as f: 310 f.write(report_proto.SerializeToString()) 311 312 def create_report(self, tt_config, tt_parameters, 313 tensor_trace_order, tensor_trace_points): 314 """Creates a report file and writes the trace information.""" 315 with OpenReportFile(tt_parameters) as self._report_file: 316 self._write_config_section(tt_config, tt_parameters) 317 self._write_op_list_section(tensor_trace_order.graph_order) 318 self._write_tensor_list_section(tensor_trace_order.graph_order) 319 self._write_trace_points(tensor_trace_points) 320 self._write_cache_index_map_section(tensor_trace_order) 321 self._write_reason_section() 322 self._write_graph_section(tensor_trace_order.graph_order) 323 324 def _write_trace_points(self, tensor_trace_points): 325 """Writes the list of checkpoints.""" 326 self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, 327 _SECTION_NAME_TENSOR_TRACER_CHECKPOINT)) 328 for (tensor, checkpoint_name) in tensor_trace_points: 329 self._write_report('%s %s\n'%(tensor.name, checkpoint_name)) 330 self._write_report('%s %s\n'%(_MARKER_SECTION_END, 331 _SECTION_NAME_TENSOR_TRACER_CHECKPOINT)) 332 333 def _write_report(self, content): 334 """Writes the given content to the report.""" 335 336 line = '%s %s'%(_TRACER_LOG_PREFIX, content) 337 if self._report_file: 338 self._report_file.write(line) 339 else: 340 logging.info(line) 341 342 def _write_config_section(self, tt_config, tt_parameters): 343 """Writes the config section of the report.""" 344 345 self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG)) 346 self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, tt_config.version)) 347 self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, tt_config.device_type)) 348 self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE, 349 tt_parameters.trace_mode)) 350 self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE, 351 tt_parameters.submode)) 352 self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS, 353 tt_config.num_replicas)) 354 self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS_PER_HOST, 355 tt_config.num_replicas_per_host)) 356 self._write_report('%s %s\n'%(_FIELD_NAME_NUM_HOSTS, tt_config.num_hosts)) 357 self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG)) 358 359 def _write_reason_section(self): 360 """Writes the reason section of the report.""" 361 362 self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON)) 363 for key in sorted(self.instrument_records): 364 self._write_report('"%s" %s\n'%(key, self.instrument_records[key])) 365 self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON)) 366 367 def _write_op_list_section(self, graph_order): 368 """Writes the Op-list section of the report.""" 369 370 self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST)) 371 self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, 372 len(graph_order.operations))) 373 for i in range(0, len(graph_order.operations)): 374 op = graph_order.operations[i] 375 line = '%d "%s" %s'%(i, op.name, op.type) 376 for out_tensor in op.outputs: 377 if out_tensor.name not in graph_order.tensor_to_idx: 378 raise ValueError( 379 'out_tensor %s is not in tensor_to_idx'%out_tensor.name) 380 line += ' %d'%graph_order.tensor_to_idx[out_tensor.name] 381 line += '\n' 382 self._write_report(line) 383 self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST)) 384 385 def _write_tensor_list_section(self, graph_order): 386 """Writes the tensor-list section of the report.""" 387 388 self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, 389 _SECTION_NAME_TENSOR_LIST)) 390 self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS, 391 len(graph_order.tensors))) 392 for i in range(0, len(graph_order.tensors)): 393 tensor = graph_order.tensors[i] 394 line = '%d "%s"'%(i, tensor.name) 395 consumers = tensor.consumers() 396 consumers.sort(key=lambda op: op.name) 397 for consumer_op in consumers: 398 if consumer_op.name not in graph_order.op_to_idx: 399 raise ValueError( 400 'consumer_op %s is not in op_to_idx'%consumer_op.name) 401 line += ' %d'%graph_order.op_to_idx[consumer_op.name] 402 line += '\n' 403 self._write_report(line) 404 self._write_report('%s %s\n'%(_MARKER_SECTION_END, 405 _SECTION_NAME_TENSOR_LIST)) 406 407 def _write_cache_index_map_section(self, tensor_trace_order): 408 """Writes the mapping from cache index to tensor index to the report.""" 409 self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, 410 _SECTION_NAME_CACHE_INDEX_MAP)) 411 self._write_report('%s %d\n'%( 412 _FIELD_NAME_NUM_CACHE_INDICES, 413 len(tensor_trace_order.cache_idx_to_tensor_idx))) 414 for cache_idx in range(0, len(tensor_trace_order.cache_idx_to_tensor_idx)): 415 tensor_idx = tensor_trace_order.cache_idx_to_tensor_idx[cache_idx] 416 line = '%d %d\n'%(cache_idx, tensor_idx) 417 self._write_report(line) 418 self._write_report('%s %s\n'%(_MARKER_SECTION_END, 419 _SECTION_NAME_CACHE_INDEX_MAP)) 420 421 def _write_graph_section(self, graph_order): 422 """Writes the graph section of the report.""" 423 424 self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH)) 425 self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED, 426 not graph_order.contains_cycle)) 427 l = list(graph_order.topological_order_or_cycle) 428 for i in range(0, len(l)): 429 self._write_report('%d "%s"\n'%(i, l[i].name)) 430 self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH)) 431