# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import io
import itertools
import threading

from absl import app

from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest

try:
  import objgraph  # pylint:disable=g-import-not-at-top
except ImportError:
  objgraph = None


def gather(strategy, value):
  """Gathers value from all workers.

  This is intended for tests before we implement an official all-gather API.

  Args:
    strategy: a `tf.distribute.Strategy`.
    value: a nested structure of n-dim `tf.distribute.DistributedValues` of
      `tf.Tensor`, or of a `tf.Tensor` if the strategy only has one replica.
      Cannot contain tf.sparse.SparseTensor.

  Returns:
    a (n+1)-dim `tf.Tensor`.
  """
  return nest.map_structure(functools.partial(_gather, strategy), value)


def _gather(strategy, value):
  """Gathers a single value."""
  # pylint: disable=protected-access
  if not isinstance(value, values.DistributedValues):
    value = values.PerReplica([ops.convert_to_tensor(value)])
  if not isinstance(strategy.extended,
                    collective_all_reduce_strategy.CollectiveAllReduceExtended):
    return array_ops.stack(value._values)
  assert len(strategy.extended.worker_devices) == len(value._values)
  inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
  return strategy.gather(values.PerReplica(inputs), axis=0)
  # pylint: enable=protected-access
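

# A minimal usage sketch, not part of the original module: the helper name and
# the assumption that a TF 2.x `tf.distribute.Strategy` with a `run` method is
# supplied by the surrounding test are illustrative only.
def _example_gather_replica_outputs(strategy):
  """Illustrative only: stacks one tensor per replica onto the host."""

  def _replica_fn():
    # Each replica returns a 1-D tensor; `gather` stacks them into a 2-D one.
    return array_ops.ones([2])

  per_replica = strategy.run(_replica_fn)
  return gather(strategy, per_replica)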


def set_logical_devices_to_at_least(device, num):
  """Create logical devices of at least a given number."""
  if num < 1:
    raise ValueError("`num` must be at least 1 not %r" % (num,))
  physical_devices = config.list_physical_devices(device)
  if not physical_devices:
    raise RuntimeError("No {} found".format(device))
  if len(physical_devices) >= num:
    return
  # By default each physical device corresponds to one logical device. We create
  # multiple logical devices for the last physical device so that we have `num`
  # logical devices.
  num = num - len(physical_devices) + 1
  logical_devices = []
  for _ in range(num):
    if device.upper() == "GPU":
      logical_devices.append(
          context.LogicalDeviceConfiguration(memory_limit=2048))
    else:
      logical_devices.append(context.LogicalDeviceConfiguration())
  # Create logical devices from the last device since sometimes the first GPU
  # is the primary graphics card and may have less memory available.
  config.set_logical_device_configuration(physical_devices[-1], logical_devices)


def _set_logical_devices():
  if config.list_physical_devices("GPU"):
    set_logical_devices_to_at_least("GPU", 2)
  if config.list_physical_devices("CPU"):
    set_logical_devices_to_at_least("CPU", 2)


def main(enable_v2_behavior=True, config_logical_devices=True):
  """All-in-one main function for tf.distribute tests."""
  if config_logical_devices:
    app.call_after_init(_set_logical_devices)
  if enable_v2_behavior:
    v2_compat.enable_v2_behavior()
  else:
    v2_compat.disable_v2_behavior()
  multi_process_runner.test_main()


def _op_dependencies(op):
  """Returns the data and control dependencies of a tf.Operation combined."""
  deps = []
  for node in itertools.chain(op.inputs, op.control_inputs):
    if isinstance(node, ops.Tensor):
      node = node.op
    assert isinstance(node, ops.Operation)
    deps.append(node)
  return deps


def topological_sort_operations(operations):
  """Topologically sorts a list of operations.

  This does a topological sort of the operations in a graph. The edges include
  both data dependencies and control dependencies. Note that the edge goes from
  an operation to its dependencies.

  Args:
    operations: a list of tf.Operation in the same graph.

  Returns:
    A map from a tf.Operation to its topological order.
  """
  in_degrees = {}
  for op in operations:
    if op not in in_degrees:
      in_degrees[op] = 0
    for next_op in _op_dependencies(op):
      in_degrees[next_op] = in_degrees.get(next_op, 0) + 1
  nexts = []
  for op, in_degree in in_degrees.items():
    if in_degree == 0:
      nexts.append(op)
  order = {}
  next_order = 0
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    order[op] = next_order
    next_order += 1
    for next_op in _op_dependencies(op):
      in_degrees[next_op] -= 1
      if in_degrees[next_op] == 0:
        nexts.append(next_op)
  assert len(order) == len(operations)
  return order


def _exists_dependency(start, end):
  """Returns whether there exists a dependency chain from start to end."""
  nexts = [start]
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    for next_op in _op_dependencies(op):
      if next_op == end:
        return True
      nexts.append(next_op)
  return False


def assert_sequential_execution(order, operations):
  """Asserts there's a deterministic execution order between the operations.

  Args:
    order: a map from a tf.Operation to its topological order.
    operations: a list of operations that should be executed sequentially. It
      can be given in any order.
  """
  # Topological ordering guarantees that, if there's a dependency from N_a to
  # N_b, then order[N_a] < order[N_b]. If there does exist a path of
  # dependencies among the operations, it always goes from an operation with a
  # smaller topological order to one with a larger topological order.
  # Therefore, we only need to sort the operations by their topological orders,
  # and verify that there's a path of dependency between adjacent pairs.
  operations = sorted(operations, key=lambda op: order[op])
  for i in range(len(operations) - 1):
    if not _exists_dependency(operations[i], operations[i + 1]):
      print(operations[i].graph.as_graph_def())
      raise AssertionError(
          "No dependency between {} and {}. Graph is dumped to stdout.".format(
              operations[i].name, operations[i + 1].name))
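

# A minimal usage sketch, not part of the original module: `op_a` and `op_b`
# stand in for whatever operations a graph-mode test actually builds, and the
# helper name is illustrative only.
def _example_assert_ops_run_in_order(op_a, op_b):
  """Illustrative only: checks that `op_a` and `op_b` have a fixed order."""
  order = topological_sort_operations(op_a.graph.get_operations())
  assert_sequential_execution(order, [op_a, op_b])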


def get_running_threads():
  """Returns a set of all running thread names."""
  running_threads = set()
  for thread in threading.enumerate():
    if thread.name is not None:
      running_threads.add(thread.name)
  return running_threads


def has_thread(prefix, running_threads):
  """Returns whether any of `running_threads` is prefixed with `prefix`.

  Args:
    prefix: The prefix of the expected thread name.
    running_threads: A collection of the running thread names.
  """
  for thread in running_threads:
    if thread.startswith(prefix):
      return True
  return False


def show_backref(target, max_depth=3):
  """Returns a dot graph of all the objects that are referencing the target.

  An object reference graph is useful for debugging memory leaks such as
  circular references. objgraph provides a better visualization of the memory
  graph than most Python built-in utilities like gc.get_referrers(), which are
  sometimes not human-readable.

  The dot graph will be written to a string IO object, and can be rendered with
  graphviz on the operating system, e.g. dot -Tpng {$dot_graph} -o output.png

  Args:
    target: The target object for the memory graph.
    max_depth: The maximum depth of the graph. By default 3 layers of references
      are used. Increasing this a lot may result in the graph growing too big.

  Returns:
    A string that contains the object reference graph.

  Raises:
    NotImplementedError: if objgraph is not installed.
  """
  if objgraph is None:
    raise NotImplementedError("objgraph is not installed.")
  string_io = io.StringIO()
  objgraph.show_backrefs(target, max_depth=max_depth, output=string_io)
  graph = string_io.getvalue()
  string_io.close()
  return graph


def create_per_replica(strategy, value_list):
  """Creates a PerReplica of Tensors from the value_list."""
  if len(strategy.extended.worker_devices) != len(value_list):
    raise ValueError(
        "the length of values must be the same as the number of worker devices")
  tensors = []
  for device, value in zip(strategy.extended.worker_devices, value_list):
    with ops.device(device):
      tensors.append(ops.convert_to_tensor(value))
  return values.PerReplica(tensors)


def is_tpu_strategy(strategy):
  """Returns whether the strategy is a TPU strategy."""
  return isinstance(strategy,
                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
                     tpu_strategy.TPUStrategyV2))
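

# A minimal usage sketch, not part of the original module: it assumes a
# two-replica, non-TPU `strategy` supplied by the surrounding test; the helper
# name and the input values are illustrative only.
def _example_build_per_replica_inputs(strategy):
  """Illustrative only: builds one input tensor per replica."""
  if is_tpu_strategy(strategy):
    # TPU tests often need special input handling; this sketch skips them.
    return None
  return create_per_replica(strategy, [1.0, 2.0])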