# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import io
import itertools
import threading

from absl import app

from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest

try:
  import objgraph  # pylint:disable=g-import-not-at-top
except ImportError:
  objgraph = None


def gather(strategy, value):
  """Gathers value from all workers.

  This is intended for tests before we implement an official all-gather API.

  Args:
    strategy: a `tf.distribute.Strategy`.
    value: a nested structure of n-dim `tf.distribute.DistributedValues` of
      `tf.Tensor`, or of a `tf.Tensor` if the strategy only has one replica.
      Cannot contain `tf.sparse.SparseTensor`.

  Returns:
    a (n+1)-dim `tf.Tensor`.
  """
  return nest.map_structure(functools.partial(_gather, strategy), value)


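# A minimal usage sketch, not part of the original module: it gathers the
# per-replica outputs of `Strategy.run` into one stacked tensor. The strategy,
# device list and tensor values below are illustrative assumptions; two logical
# CPU devices are expected (see `set_logical_devices_to_at_least`).
def _example_gather():
  from tensorflow.python.distribute import mirrored_strategy  # pylint:disable=g-import-not-at-top
  from tensorflow.python.framework import constant_op  # pylint:disable=g-import-not-at-top
  strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/cpu:1"])
  # Each replica produces a 1-dim tensor, so the gathered result is 2-dim with
  # one row per replica.
  per_replica = strategy.run(lambda: constant_op.constant([1.0, 2.0]))
  return gather(strategy, per_replica)

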
def _gather(strategy, value):
  """Gathers a single value."""
  # pylint: disable=protected-access
  if not isinstance(value, values.DistributedValues):
    value = values.PerReplica([ops.convert_to_tensor(value)])
  if not isinstance(strategy.extended,
                    collective_all_reduce_strategy.CollectiveAllReduceExtended):
    return array_ops.stack(value._values)
  assert len(strategy.extended.worker_devices) == len(value._values)
  inputs = [array_ops.expand_dims_v2(v, axis=0) for v in value._values]
  return strategy.gather(values.PerReplica(inputs), axis=0)
  # pylint: enable=protected-access


def set_logical_devices_to_at_least(device, num):
  """Ensures that there are at least `num` logical devices of type `device`."""
  if num < 1:
    raise ValueError("`num` must be at least 1, not %r" % (num,))
  physical_devices = config.list_physical_devices(device)
  if not physical_devices:
    raise RuntimeError("No {} found".format(device))
  if len(physical_devices) >= num:
    return
  # By default each physical device corresponds to one logical device. We
  # create multiple logical devices for the last physical device so that we
  # have `num` logical devices in total.
  num = num - len(physical_devices) + 1
  logical_devices = []
  for _ in range(num):
    if device.upper() == "GPU":
      logical_devices.append(
          context.LogicalDeviceConfiguration(memory_limit=2048))
    else:
      logical_devices.append(context.LogicalDeviceConfiguration())
  # Create logical devices from the last device since sometimes the first GPU
  # is the primary graphics card and may have less memory available.
  config.set_logical_device_configuration(physical_devices[-1], logical_devices)


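# A brief usage sketch (an illustration, not code from the original file): on a
# machine with a single physical CPU this ensures two logical CPU devices, so
# multi-replica strategies like `MirroredStrategy(["/cpu:0", "/cpu:1"])` can be
# built. It must run before TensorFlow initializes its devices.
def _example_ensure_two_cpus():
  set_logical_devices_to_at_least("CPU", 2)
  return config.list_logical_devices("CPU")  # now reports at least two devices

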
def _set_logical_devices():
  if config.list_physical_devices("GPU"):
    set_logical_devices_to_at_least("GPU", 2)
  if config.list_physical_devices("CPU"):
    set_logical_devices_to_at_least("CPU", 2)


def main(enable_v2_behavior=True, config_logical_devices=True):
  """All-in-one main function for tf.distribute tests."""
  if config_logical_devices:
    app.call_after_init(_set_logical_devices)
  if enable_v2_behavior:
    v2_compat.enable_v2_behavior()
  else:
    v2_compat.disable_v2_behavior()
  multi_process_runner.test_main()


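# Typical calling convention in a tf.distribute test file, assuming this module
# is importable as `test_util` (a sketch, not code from the original file):
#
#   from tensorflow.python.distribute import test_util
#
#   if __name__ == "__main__":
#     test_util.main()

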
def _op_dependencies(op):
  """Returns the data and control dependencies of a tf.Operation combined."""
  deps = []
  for node in itertools.chain(op.inputs, op.control_inputs):
    if isinstance(node, ops.Tensor):
      node = node.op
    assert isinstance(node, ops.Operation)
    deps.append(node)
  return deps


def topological_sort_operations(operations):
  """Topologically sorts a list of operations.

  This does a topological sort of the operations in a graph. The edges include
  both data dependencies and control dependencies. Note that the edge goes from
  an operation to its dependencies.

  Args:
    operations: a list of tf.Operation in the same graph.

  Returns:
    A map from a tf.Operation to its topological order.
  """
  in_degrees = {}
  for op in operations:
    if op not in in_degrees:
      in_degrees[op] = 0
    for next_op in _op_dependencies(op):
      in_degrees[next_op] = in_degrees.get(next_op, 0) + 1
  nexts = []
  for op, in_degree in in_degrees.items():
    if in_degree == 0:
      nexts.append(op)
  order = {}
  next_order = 0
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    order[op] = next_order
    next_order += 1
    for next_op in _op_dependencies(op):
      in_degrees[next_op] -= 1
      if in_degrees[next_op] == 0:
        nexts.append(next_op)
  assert len(order) == len(operations)
  return order


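# A hypothetical sketch of the expected output: ops that nothing depends on get
# the smallest orders, and an op's dependencies always come after it (edges go
# from an operation to its dependencies). The tiny graph below is illustrative
# only, built with locally imported helpers the original module does not need.
def _example_topological_sort():
  from tensorflow.python.framework import constant_op  # pylint:disable=g-import-not-at-top
  from tensorflow.python.ops import math_ops  # pylint:disable=g-import-not-at-top
  with ops.Graph().as_default() as graph:
    a = constant_op.constant(1.0)
    b = constant_op.constant(2.0)
    c = math_ops.add(a, b)
  order = topological_sort_operations(graph.get_operations())
  # `c` consumes `a` and `b`, so it is ordered before both of them.
  assert order[c.op] < order[a.op]
  assert order[c.op] < order[b.op]
  return order

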
def _exists_dependency(start, end):
  """Returns whether there exists a dependency chain from start to end."""
  nexts = [start]
  while nexts:
    op, nexts = nexts[0], nexts[1:]
    for next_op in _op_dependencies(op):
      if next_op == end:
        return True
      nexts.append(next_op)
  return False


def assert_sequential_execution(order, operations):
  """Asserts there's a deterministic execution order between the operations.

  Args:
    order: a map from a tf.Operation to its topological order.
    operations: a list of operations that should be executed sequentially. It
      can be given in any order.
  """
  # Topological ordering guarantees that, if there's a dependency from N_a to
  # N_b, then order[N_a] < order[N_b]. If there does exist a path of
  # dependencies among the operations, it always goes from an operation with a
  # smaller topological order to one with a larger topological order.
  # Therefore, we only need to sort the operations by their topological orders,
  # and verify that there's a path of dependency between adjacent pairs.
  operations = sorted(operations, key=lambda op: order[op])
  for i in range(len(operations) - 1):
    if not _exists_dependency(operations[i], operations[i + 1]):
      print(operations[i].graph.as_graph_def())
      raise AssertionError(
          "No dependency between {} and {}. Graph is dumped to stdout.".format(
              operations[i].name, operations[i + 1].name))


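# A hypothetical sketch, not from the original file: a control dependency
# forces `second` to run after `first`, so the assertion passes. The op names
# and constants below are illustrative assumptions.
def _example_assert_sequential():
  from tensorflow.python.framework import constant_op  # pylint:disable=g-import-not-at-top
  with ops.Graph().as_default() as graph:
    first = constant_op.constant(1.0)
    with ops.control_dependencies([first]):
      second = array_ops.identity(constant_op.constant(2.0))
  order = topological_sort_operations(graph.get_operations())
  # Would raise AssertionError if there were no dependency chain between them.
  assert_sequential_execution(order, [first.op, second.op])

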
def get_running_threads():
  """Returns a set of all running thread names."""
  running_threads = set()
  for thread in threading.enumerate():
    if thread.name is not None:
      running_threads.add(thread.name)
  return running_threads


def has_thread(prefix, running_threads):
  """Returns whether any thread in `running_threads` starts with `prefix`.

  Args:
    prefix: The prefix of the expected thread name.
    running_threads: A collection of the running thread names.
  """
  for thread in running_threads:
    if thread.startswith(prefix):
      return True
  return False


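# A tiny usage sketch (illustrative, not from the original file): the Python
# main thread is always alive and named "MainThread", so a prefix lookup on the
# running thread names finds it.
def _example_has_thread():
  return has_thread("MainThread", get_running_threads())  # True

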
def show_backref(target, max_depth=3):
  """Returns a dot graph of all the objects that are referencing the target.

  An object reference graph is useful for debugging memory leaks such as
  circular references. objgraph provides a better visualization of the memory
  graph than most Python built-in utilities like gc.get_referrers(), whose
  output is often not human-readable.

  The dot graph will be written to a string IO object, and can be rendered with
  graphviz on the operating system, e.g. dot -Tpng {$dot_graph} -o output.png.

  Args:
    target: The target object for the memory graph.
    max_depth: The maximum depth of the graph. By default 3 layers of references
      are used. Increasing this a lot may result in the graph growing too big.

  Returns:
    A string that contains the object reference graph.

  Raises:
    NotImplementedError: if objgraph is not installed.
  """
  if objgraph is None:
    raise NotImplementedError("objgraph is not installed.")
  string_io = io.StringIO()
  objgraph.show_backrefs(target, max_depth=max_depth, output=string_io)
  graph = string_io.getvalue()
  string_io.close()
  return graph


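# A hypothetical sketch: dumps the reference graph of a suspected leaked object
# to a .dot file, which can then be rendered with graphviz, e.g.
# `dot -Tpng /tmp/leak.dot -o /tmp/leak.png`. The file path is an illustrative
# assumption, and the optional objgraph dependency must be installed.
def _example_show_backref(target):
  dot_graph = show_backref(target, max_depth=3)
  with open("/tmp/leak.dot", "w") as f:
    f.write(dot_graph)

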
def create_per_replica(strategy, value_list):
  """Creates a PerReplica of Tensors from the value_list."""
  if len(strategy.extended.worker_devices) != len(value_list):
    raise ValueError(
        "the length of values must be the same as the number of worker devices")
  tensors = []
  for device, value in zip(strategy.extended.worker_devices, value_list):
    with ops.device(device):
      tensors.append(ops.convert_to_tensor(value))
  return values.PerReplica(tensors)


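# A brief sketch tying the helpers together (the strategy, devices and values
# below are illustrative assumptions): place one tensor per worker device, then
# gather them back into a single stacked tensor.
def _example_create_per_replica():
  from tensorflow.python.distribute import mirrored_strategy  # pylint:disable=g-import-not-at-top
  strategy = mirrored_strategy.MirroredStrategy(["/cpu:0", "/cpu:1"])
  per_replica = create_per_replica(strategy, [[1.0, 2.0], [3.0, 4.0]])
  return gather(strategy, per_replica)  # 2x2 tensor, one row per replica

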
def is_tpu_strategy(strategy):
  """Returns whether the strategy is a TPU strategy."""
  return isinstance(strategy,
                    (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1,
                     tpu_strategy.TPUStrategyV2))