# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import itertools

from absl import app

from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest


def gather(strategy, value):
  """Gathers value from all workers.

  This is intended for tests before we implement an official all-gather API.

  Args:
    strategy: a `tf.distribute.Strategy`.
    value: a nested structure of n-dim `tf.distribute.DistributedValues` of
      `tf.Tensor`, or of a `tf.Tensor` if the strategy only has one replica.
      Cannot contain `tf.sparse.SparseTensor`.

  Returns:
    an (n+1)-dim `tf.Tensor`.
  """
  return nest.map_structure(functools.partial(_gather, strategy), value)


def _gather(strategy, value):
  """Gathers a single value."""
  # pylint: disable=protected-access
  if not isinstance(value, values.DistributedValues):
    value = values.PerReplica([ops.convert_to_tensor(value)])
  if not isinstance(strategy.extended,
                    collective_all_reduce_strategy.CollectiveAllReduceExtended):
    return array_ops.stack(value._values)
  assert len(strategy.extended.worker_devices) == len(value._values)
  inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
  return strategy.gather(values.PerReplica(inputs), axis=0)
  # pylint: enable=protected-access


def set_logical_devices_to_at_least(device, num):
  """Creates logical devices of at least a given number."""
  if num < 1:
    raise ValueError("`num` must be at least 1, not %r" % (num,))
  physical_devices = config.list_physical_devices(device)
  if not physical_devices:
    raise RuntimeError("No {} found".format(device))
  if len(physical_devices) >= num:
    return
  # By default each physical device corresponds to one logical device. We
  # create multiple logical devices for the last physical device so that we
  # have `num` logical devices in total.
  num = num - len(physical_devices) + 1
  logical_devices = []
  for _ in range(num):
    if device.upper() == "GPU":
      logical_devices.append(
          context.LogicalDeviceConfiguration(memory_limit=2048))
    else:
      logical_devices.append(context.LogicalDeviceConfiguration())
  # Create logical devices from the last device since sometimes the first GPU
  # is the primary graphics card and may have less memory available.
  config.set_logical_device_configuration(physical_devices[-1], logical_devices)
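

# Illustrative sketch (not used by any test): shows how a test might combine
# `set_logical_devices_to_at_least` and `gather`. The choice of a two-replica
# `MirroredStrategy` over virtual CPU devices is an assumption made only for
# this example.
def _example_gather_usage():
  """Sketch: stack per-replica results from a two-replica strategy."""
  from tensorflow.python.distribute import mirrored_strategy  # pylint: disable=g-import-not-at-top
  # Ensure there are at least two logical CPU devices to place replicas on.
  set_logical_devices_to_at_least("CPU", 2)
  strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/cpu:1"])
  per_replica = strategy.run(lambda: array_ops.ones([2]))
  # `gather` adds a leading replica axis, so the result has shape
  # [num_replicas, 2] here.
  return gather(strategy, per_replica)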


def _set_logical_devices():
  if config.list_physical_devices("GPU"):
    set_logical_devices_to_at_least("GPU", 2)
  if config.list_physical_devices("CPU"):
    set_logical_devices_to_at_least("CPU", 2)


def main(enable_v2_behavior=True, config_logical_devices=True):
  """All-in-one main function for tf.distribute tests."""
  if config_logical_devices:
    app.call_after_init(_set_logical_devices)
  if enable_v2_behavior:
    v2_compat.enable_v2_behavior()
  else:
    v2_compat.disable_v2_behavior()
  # TODO(b/131360402): configure default logical devices.
  multi_process_runner.test_main()
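

# Illustrative usage sketch: a tf.distribute test module typically delegates
# to `main` so that logical devices are configured and multi-process tests are
# handled. The import path below is an assumption made for the example.
#
#   from tensorflow.python.distribute import test_util
#
#   if __name__ == "__main__":
#     test_util.main()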


def _op_dependencies(op):
  """Returns the data and control dependencies of a tf.Operation combined."""
  deps = []
  for node in itertools.chain(op.inputs, op.control_inputs):
    if isinstance(node, ops.Tensor):
      node = node.op
    assert isinstance(node, ops.Operation)
    deps.append(node)
  return deps


def topological_sort_operations(operations):
  """Topologically sorts a list of operations.

  This does a topological sort of the operations in a graph. The edges include
  both data dependencies and control dependencies. Note that the edge goes from
  an operation to its dependencies.

  Args:
    operations: a list of tf.Operation in the same graph.

  Returns:
    A map from a tf.Operation to its topological order.
  """
  in_degrees = {}
  for op in operations:
    if op not in in_degrees:
      in_degrees[op] = 0
    for next_op in _op_dependencies(op):
      in_degrees[next_op] = in_degrees.get(next_op, 0) + 1
  nexts = []
  for op, in_degree in in_degrees.items():
    if in_degree == 0:
      nexts.append(op)
  order = {}
  next_order = 0
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    order[op] = next_order
    next_order += 1
    for next_op in _op_dependencies(op):
      in_degrees[next_op] -= 1
      if in_degrees[next_op] == 0:
        nexts.append(next_op)
  assert len(order) == len(operations)
  return order


def _exists_dependency(start, end):
  """Returns whether there exists a dependency chain from start to end."""
  nexts = [start]
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    for next_op in _op_dependencies(op):
      if next_op == end:
        return True
      nexts.append(next_op)
  return False


def assert_sequential_execution(order, operations):
  """Asserts there's a deterministic execution order between the operations.

  Args:
    order: a map from a tf.Operation to its topological order.
    operations: a list of operations that should be executed sequentially. It
      can be given in any order.
  """
  # Topological ordering guarantees that, if there's a dependency from N_a to
  # N_b, then order[N_a] < order[N_b]. If there exists a path of dependencies
  # among the operations, it always goes from an operation with a smaller
  # topological order to one with a larger topological order. Therefore, we
  # only need to sort the operations by their topological orders, and verify
  # that there's a dependency path between each adjacent pair.
  operations = sorted(operations, key=lambda op: order[op])
  for i in range(len(operations) - 1):
    if not _exists_dependency(operations[i], operations[i + 1]):
      print(operations[i].graph.as_graph_def())
      raise AssertionError(
          "No dependency between {} and {}. Graph is dumped to stdout.".format(
              operations[i].name, operations[i + 1].name))
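

# Illustrative sketch (not used by any test): shows how
# `topological_sort_operations` and `assert_sequential_execution` fit together
# on a tiny graph. The op names and the control dependency below are
# assumptions made only for this example.
def _example_sequential_execution_check():
  """Sketch: check that two ops in a graph run in a deterministic order."""
  graph = ops.Graph()
  with graph.as_default():
    a = array_ops.identity(1.0, name="a")
    # The control dependency forces "b" to run after "a".
    with ops.control_dependencies([a]):
      b = array_ops.identity(2.0, name="b")
  order = topological_sort_operations(graph.get_operations())
  # "b" depends on "a" through the control edge, so the pair has a
  # deterministic execution order and this does not raise.
  assert_sequential_execution(order, [a.op, b.op])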