# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import itertools

from absl import app

from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest


def gather(strategy, value):
  """Gathers value from all workers.

  This is intended for tests before we implement an official all-gather API.

  Args:
    strategy: a `tf.distribute.Strategy`.
    value: a nested structure of n-dim `tf.distribute.DistributedValues` of
      `tf.Tensor`, or of a `tf.Tensor` if the strategy only has one replica.
      Cannot contain `tf.sparse.SparseTensor`.

  Returns:
    an (n+1)-dim `tf.Tensor`.
  """
  return nest.map_structure(functools.partial(_gather, strategy), value)


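# A minimal usage sketch for the `gather` helper above (illustrative only; the
# strategy and values here are hypothetical and not created by this module):
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   per_replica = strategy.run(lambda: tf.constant([1.0]))
#   stacked = gather(strategy, per_replica)
#   # `stacked` has shape [2, 1]: one row per replica.
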
def _gather(strategy, value):
  """Gathers a single value."""
  # pylint: disable=protected-access
  if not isinstance(value, values.DistributedValues):
    value = values.PerReplica([ops.convert_to_tensor(value)])
  if not isinstance(strategy.extended,
                    collective_all_reduce_strategy.CollectiveAllReduceExtended):
    # Without collective ops (e.g. a single-worker strategy), all per-replica
    # values are available locally and can simply be stacked.
    return array_ops.stack(value._values)
  assert len(strategy.extended.worker_devices) == len(value._values)
  # With a multi-worker collective strategy, add a leading axis to each
  # per-replica value and all-gather them along that axis.
  inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
  return strategy.gather(values.PerReplica(inputs), axis=0)
  # pylint: enable=protected-access


def set_logical_devices_to_at_least(device, num):
  """Create logical devices of at least a given number."""
  if num < 1:
    raise ValueError("`num` must be at least 1, not %r" % (num,))
  physical_devices = config.list_physical_devices(device)
  if not physical_devices:
    raise RuntimeError("No {} found".format(device))
  if len(physical_devices) >= num:
    return
  # By default each physical device corresponds to one logical device. We
  # create multiple logical devices for the last physical device so that we
  # have `num` logical devices in total.
  num = num - len(physical_devices) + 1
  logical_devices = []
  for _ in range(num):
    if device.upper() == "GPU":
      logical_devices.append(
          context.LogicalDeviceConfiguration(memory_limit=2048))
    else:
      logical_devices.append(context.LogicalDeviceConfiguration())
  # Create logical devices from the last device since sometimes the first GPU
  # is the primary graphics card and may have less memory available.
  config.set_logical_device_configuration(physical_devices[-1], logical_devices)


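# Usage sketch (illustrative only): a test that needs two local devices on a
# single-device host can call
#
#   set_logical_devices_to_at_least("CPU", 2)
#
# before TensorFlow initializes its devices, and then build a strategy over
# ["CPU:0", "CPU:1"].
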
def _set_logical_devices():
  if config.list_physical_devices("GPU"):
    set_logical_devices_to_at_least("GPU", 2)
  if config.list_physical_devices("CPU"):
    set_logical_devices_to_at_least("CPU", 2)


def main(enable_v2_behavior=True, config_logical_devices=True):
  """All-in-one main function for tf.distribute tests."""
  if config_logical_devices:
    app.call_after_init(_set_logical_devices)
  if enable_v2_behavior:
    v2_compat.enable_v2_behavior()
  else:
    v2_compat.disable_v2_behavior()
  # TODO(b/131360402): configure default logical devices.
  multi_process_runner.test_main()


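# Typical entry point sketch for a tf.distribute test module (the `test_util`
# import name below is an assumption about how this module is imported):
#
#   if __name__ == "__main__":
#     test_util.main()
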
def _op_dependencies(op):
  """Returns the data and control dependencies of a tf.Operation combined."""
  deps = []
  for node in itertools.chain(op.inputs, op.control_inputs):
    if isinstance(node, ops.Tensor):
      node = node.op
    assert isinstance(node, ops.Operation)
    deps.append(node)
  return deps


def topological_sort_operations(operations):
  """Topologically sorts a list of operations.

  This does a topological sort of the operations in a graph. The edges include
  both data dependencies and control dependencies. Note that the edge goes from
  an operation to its dependencies.

  Args:
    operations: a list of tf.Operation in the same graph.

  Returns:
    A map from a tf.Operation to its topological order.
  """
  in_degrees = {}
  for op in operations:
    if op not in in_degrees:
      in_degrees[op] = 0
    for next_op in _op_dependencies(op):
      in_degrees[next_op] = in_degrees.get(next_op, 0) + 1
  nexts = []
  for op, in_degree in in_degrees.items():
    if in_degree == 0:
      nexts.append(op)
  order = {}
  next_order = 0
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    order[op] = next_order
    next_order += 1
    for next_op in _op_dependencies(op):
      in_degrees[next_op] -= 1
      if in_degrees[next_op] == 0:
        nexts.append(next_op)
  assert len(order) == len(operations)
  return order


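# Illustrative sketch of the ordering convention (hypothetical ops): edges
# point from an operation to its dependencies, so a consumer receives a
# smaller order than the operations it depends on.
#
#   with ops.Graph().as_default() as g:
#     a = array_ops.zeros([2])     # producer
#     b = array_ops.identity(a)    # consumer of `a`
#     order = topological_sort_operations(g.get_operations())
#     assert order[b.op] < order[a.op]
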
def _exists_dependency(start, end):
  """Returns whether there exists a dependency chain from start to end."""
  nexts = [start]
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    for next_op in _op_dependencies(op):
      if next_op == end:
        return True
      nexts.append(next_op)
  return False


def assert_sequential_execution(order, operations):
  """Asserts there's a deterministic execution order between the operations.

  Args:
    order: a map from a tf.Operation to its topological order.
    operations: a list of operations that should be executed sequentially. It
      can be given in any order.
  """
  # Topological ordering guarantees that, if there's a dependency from N_a to
  # N_b, then order[N_a] < order[N_b]. If there exists a path of dependencies
  # among the operations, it always goes from an operation with a smaller
  # topological order to one with a larger topological order. Therefore, we
  # only need to sort the operations by their topological orders, and verify
  # that there's a path of dependency between adjacent pairs.
  operations = sorted(operations, key=lambda op: order[op])
  for i in range(len(operations) - 1):
    if not _exists_dependency(operations[i], operations[i + 1]):
      print(operations[i].graph.as_graph_def())
      raise AssertionError(
          "No dependency between {} and {}. Graph is dumped to stdout.".format(
              operations[i].name, operations[i + 1].name))

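# Combined usage sketch (hypothetical test code): after tracing a function that
# should run two collectives back to back, a test can check that the graph
# forces an execution order between them.
#
#   graph = traced_fn.get_concrete_function().graph
#   order = topological_sort_operations(graph.get_operations())
#   collectives = [op for op in graph.get_operations()
#                  if op.type.startswith("Collective")]
#   assert_sequential_execution(order, collectives)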