• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16# pylint: disable=invalid-name
17"""Test utils for tensorflow."""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23from collections import OrderedDict
24import contextlib
25import functools
26import gc
27import itertools
28import math
29import os
30import random
31import re
32import tempfile
33import threading
34import time
35import unittest
36
37from absl.testing import parameterized
38import numpy as np
39import six
40
41from google.protobuf import descriptor_pool
42from google.protobuf import text_format
43
44from tensorflow.core.framework import graph_pb2
45from tensorflow.core.protobuf import rewriter_config_pb2
46from tensorflow.python import tf2
47from tensorflow.python.client import device_lib
48from tensorflow.python.client import pywrap_tf_session
49from tensorflow.python.client import session
50from tensorflow.python.compat.compat import forward_compatibility_horizon
51from tensorflow.python.eager import backprop
52from tensorflow.python.eager import context
53from tensorflow.python.eager import def_function
54from tensorflow.python.eager import tape
55from tensorflow.python.framework import config
56from tensorflow.python.framework import device as pydev
57from tensorflow.python.framework import dtypes
58from tensorflow.python.framework import errors
59from tensorflow.python.framework import errors_impl
60from tensorflow.python.framework import gpu_util
61from tensorflow.python.framework import importer
62from tensorflow.python.framework import ops
63from tensorflow.python.framework import random_seed
64from tensorflow.python.framework import sparse_tensor
65from tensorflow.python.framework import tensor_shape
66from tensorflow.python.framework import tensor_util
67from tensorflow.python.framework import tfrt_utils
68from tensorflow.python.framework import versions
69from tensorflow.python.ops import array_ops
70from tensorflow.python.ops import control_flow_util
71from tensorflow.python.ops import control_flow_util_v2
72from tensorflow.python.ops import gradients_impl
73from tensorflow.python.ops import math_ops
74from tensorflow.python.ops import script_ops
75from tensorflow.python.ops import summary_ops_v2
76from tensorflow.python.ops import variables
77from tensorflow.python.ops.ragged import ragged_ops  # pylint: disable=unused-import
78from tensorflow.python.ops.ragged import ragged_tensor
79from tensorflow.python.ops.ragged import ragged_tensor_value
80from tensorflow.python.platform import _pywrap_stacktrace_handler
81from tensorflow.python.platform import googletest
82from tensorflow.python.platform import tf_logging as logging
83from tensorflow.python.training import server_lib
84from tensorflow.python.util import _pywrap_util_port
85from tensorflow.python.util import compat
86from tensorflow.python.util import deprecation
87from tensorflow.python.util import nest
88from tensorflow.python.util import tf_decorator
89from tensorflow.python.util import tf_inspect
90from tensorflow.python.util.compat import collections_abc
91from tensorflow.python.util.protobuf import compare
92from tensorflow.python.util.tf_export import tf_export
93
94
# If the below import is made available through the BUILD rule, then this
# function is overridden and will instead return True and cause Tensorflow
# graphs to be compiled with XLA.
def is_xla_enabled():
  """Returns False; a test build dependency may override this (see below)."""
  return False
100
101
# If the is_xla_test_true module is linked in via the BUILD rule, this import
# replaces the default is_xla_enabled defined above.
try:
  from tensorflow.python.framework.is_xla_test_true import is_xla_enabled  # pylint: disable=g-import-not-at-top, unused-import
except Exception:  # pylint: disable=broad-except
  # The override module is optional; keep the default definition.
  pass
106
107
# Uses the same mechanism as above to selectively enable/disable MLIR
# compilation.
def is_mlir_bridge_enabled():
  """Returns None (unset); a test build dependency may override this."""
  return None
112
113
# Optional override of is_mlir_bridge_enabled above: depending on which (if
# any) override module is linked in via the BUILD rule, the MLIR bridge is
# reported force-disabled, force-enabled, or left at the default (None).
try:
  from tensorflow.python.framework.is_mlir_bridge_test_false import is_mlir_bridge_enabled  # pylint: disable=g-import-not-at-top, unused-import
except ImportError:
  try:
    from tensorflow.python.framework.is_mlir_bridge_test_true import is_mlir_bridge_enabled  # pylint: disable=g-import-not-at-top, unused-import
  except ImportError:
    # Neither override module is available; keep the default definition.
    pass
121
122
123def _get_object_count_by_type(exclude=()):
124  return (
125      collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) -
126      collections.Counter([type(obj).__name__ for obj in exclude]))
127
128
@tf_export("test.gpu_device_name")
def gpu_device_name():
  """Returns the name of a GPU device if available or the empty string."""
  local_devices = device_lib.list_local_devices()
  gpu_names = [
      compat.as_str(d.name) for d in local_devices if d.device_type == "GPU"
  ]
  # Report the first GPU found, matching device enumeration order.
  return gpu_names[0] if gpu_names else ""
136
137
def assert_ops_in_graph(expected_ops, graph):
  """Assert all expected operations are found.

  Args:
    expected_ops: `dict<string, string>` of op name to op type.
    graph: Graph to check.

  Returns:
    `dict<string, node>` of node name to node.

  Raises:
    ValueError: If the expected ops are not present in the graph.
  """
  graph_def = graph.as_graph_def()
  actual_ops = {}
  for node in graph_def.node:
    if node.name not in expected_ops:
      continue
    expected_type = expected_ops[node.name]
    if expected_type != node.op:
      raise ValueError("Expected op for node %s is different. %s vs %s" %
                       (node.name, expected_type, node.op))
    actual_ops[node.name] = node
  # Every expected op name must have been matched exactly once.
  if set(expected_ops.keys()) != set(actual_ops.keys()):
    raise ValueError("Not all expected ops are present. Expected %s, found %s" %
                     (expected_ops.keys(), actual_ops.keys()))
  return actual_ops
163
164
@tf_export("test.assert_equal_graph_def", v1=[])
def assert_equal_graph_def_v2(expected, actual):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs.  Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent. This function
  ignores randomized attribute values that may appear in V2 checkpoints.

  Args:
    expected: The `GraphDef` we expected.
    actual: The `GraphDef` we have.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  # The V2 API always strips both checkpoint- and hash-table-randomized names.
  assert_equal_graph_def(
      actual, expected, checkpoint_v2=True, hash_table_shared_name=True)
184
185
@tf_export(v1=["test.assert_equal_graph_def"])
def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False,
                              hash_table_shared_name=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs.  Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to ignore randomized attribute
      values that appear in V2 checkpoints.
    hash_table_shared_name: boolean determining whether to ignore randomized
      shared_names that appear in HashTableV2 op defs.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  # Pure pass-through to the shared implementation.
  assert_equal_graph_def(actual, expected, checkpoint_v2=checkpoint_v2,
                         hash_table_shared_name=hash_table_shared_name)
209
210
def assert_equal_graph_def(actual, expected, checkpoint_v2=False,
                           hash_table_shared_name=False):
  """Shared implementation for the v1/v2 graph-equality asserts above."""
  # Validate argument types up front so the error names the offending one.
  for label, graph_def in (("actual", actual), ("expected", expected)):
    if not isinstance(graph_def, graph_pb2.GraphDef):
      raise TypeError("Expected tf.GraphDef for %s, got %s" %
                      (label, type(graph_def).__name__))

  if checkpoint_v2:
    _strip_checkpoint_v2_randomized(actual)
    _strip_checkpoint_v2_randomized(expected)

  if hash_table_shared_name:
    _strip_hash_table_shared_name(actual)
    _strip_hash_table_shared_name(expected)

  # The native comparison returns a human-readable diff, empty on equality.
  diff = pywrap_tf_session.EqualGraphDefWrapper(actual.SerializeToString(),
                                                expected.SerializeToString())
  if diff:
    raise AssertionError(compat.as_str(diff))
232
233
def assert_meta_graph_protos_equal(tester, a, b):
  """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

  Args:
    tester: A `unittest.TestCase`-style object providing `assertEqual` and
      `assertProtoEquals`.
    a: A `MetaGraphDef` proto. NOTE: mutated — compared fields are cleared.
    b: A `MetaGraphDef` proto. NOTE: mutated — compared fields are cleared.
  """
  # Carefully check the collection_defs
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  collection_keys = a.collection_def.keys()
  for k in collection_keys:
    a_value = a.collection_def[k]
    b_value = b.collection_def[k]
    proto_type = ops.get_collection_proto_type(k)
    if proto_type:
      a_proto = proto_type()
      b_proto = proto_type()
      # Number of entries in the collections is the same
      tester.assertEqual(
          len(a_value.bytes_list.value), len(b_value.bytes_list.value))
      for (a_value_item, b_value_item) in zip(a_value.bytes_list.value,
                                              b_value.bytes_list.value):
        a_proto.ParseFromString(a_value_item)
        b_proto.ParseFromString(b_value_item)
        tester.assertProtoEquals(a_proto, b_proto)
    else:
      # Use assertEqual: assertEquals is a deprecated alias and is removed in
      # newer unittest versions.
      tester.assertEqual(a_value, b_value)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("collection_def")
  b.ClearField("collection_def")

  # Check the graph_defs.
  assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
  # Check graph_def versions (ignored by assert_equal_graph_def).
  tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("graph_def")
  b.ClearField("graph_def")

  tester.assertProtoEquals(a, b)
271
272
# Matches attributes named via _SHARDED_SUFFIX in
# tensorflow/python/training/saver.py.  Used by
# _strip_checkpoint_v2_randomized below to drop randomized save paths.
_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part"
276
277
def _strip_checkpoint_v2_randomized(graph_def):
  """Removes node attrs whose single tensor string matches the sharded-save
  pattern, so V2-checkpoint randomized paths don't break graph comparison."""
  sharded_pattern = compat.as_bytes(_SHARDED_SAVE_OP_PATTERN)
  for node in graph_def.node:
    # Collect keys first; deleting while iterating the attr map is unsafe.
    randomized_keys = []
    for attr_key, attr_value in node.attr.items():
      tensor_value = attr_value.tensor
      if tensor_value and len(tensor_value.string_val) == 1:
        string_value = tensor_value.string_val[0]
        if string_value and re.match(sharded_pattern, string_value):
          randomized_keys.append(attr_key)
    for attr_key in randomized_keys:
      del node.attr[attr_key]
291
292
293_TABLE_SHARED_NAME_PATTERN = r"hash_table_[0-9a-z\-]+"
294
295
def _strip_hash_table_shared_name(graph_def):
  """Clears randomized shared_name attrs from HashTableV2 nodes so they don't
  break graph comparison."""
  table_name_pattern = compat.as_bytes(_TABLE_SHARED_NAME_PATTERN)
  for node in graph_def.node:
    if node.op != "HashTableV2" or "shared_name" not in node.attr:
      continue
    if re.match(table_name_pattern, node.attr["shared_name"].s):
      del node.attr["shared_name"]
305
306
def IsGoogleCudaEnabled():
  """Thin wrapper around the native IsGoogleCudaEnabled build check."""
  return _pywrap_util_port.IsGoogleCudaEnabled()
309
310
def IsBuiltWithROCm():
  """Thin wrapper around the native IsBuiltWithROCm build check."""
  return _pywrap_util_port.IsBuiltWithROCm()
313
314
def IsBuiltWithXLA():
  """Thin wrapper around the native IsBuiltWithXLA build check."""
  return _pywrap_util_port.IsBuiltWithXLA()
317
318
def IsBuiltWithNvcc():
  """Thin wrapper around the native IsBuiltWithNvcc build check."""
  return _pywrap_util_port.IsBuiltWithNvcc()
321
322
def GpuSupportsHalfMatMulAndConv():
  """Thin wrapper around the native GpuSupportsHalfMatMulAndConv check."""
  return _pywrap_util_port.GpuSupportsHalfMatMulAndConv()
325
326
def IsMklEnabled():
  """Thin wrapper around the native IsMklEnabled build check."""
  return _pywrap_util_port.IsMklEnabled()
329
330
def InstallStackTraceHandler():
  """Installs the native stacktrace handler (side effect only; returns None)."""
  _pywrap_stacktrace_handler.InstallStacktraceHandler()
333
334
def NHWCToNCHW(input_tensor):
  """Converts the input from the NHWC format to NCHW.

  Args:
    input_tensor: a 3-, 4-, or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # Maps tensor rank to the axis permutation that moves channels to axis 1.
  permutations = {3: [0, 2, 1], 4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
  if not isinstance(input_tensor, ops.Tensor):
    # Shape-array case: permute the entries of the list.
    return [input_tensor[axis] for axis in permutations[len(input_tensor)]]
  return array_ops.transpose(input_tensor,
                             permutations[input_tensor.shape.ndims])
352
353
def NHWCToNCHW_VECT_C(input_shape_or_tensor):
  """Transforms the input from the NHWC layout to NCHW_VECT_C layout.

  Note: Does not include quantization or type conversion steps, which should
  be applied afterwards.

  Args:
    input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NCHW_VECT_C

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not evenly
        divisible by 4.
  """
  permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  if is_tensor:
    temp_shape = input_shape_or_tensor.shape.as_list()
  else:
    # NOTE(review): this aliases the caller's list, which is mutated below —
    # preserved behavior; confirm callers do not rely on their input list.
    temp_shape = input_shape_or_tensor
  if temp_shape[-1] % 4 != 0:
    raise ValueError(
        "Last dimension of input must be evenly divisible by 4 to convert to "
        "NCHW_VECT_C.")
  # Split the channel dimension into (C // 4, 4).
  temp_shape[-1] //= 4
  temp_shape.append(4)
  permutation = permutations[len(temp_shape)]
  if not is_tensor:
    return [temp_shape[axis] for axis in permutation]
  reshaped = array_ops.reshape(input_shape_or_tensor, temp_shape)
  return array_ops.transpose(reshaped, permutation)
387
388
def NCHW_VECT_CToNHWC(input_shape_or_tensor):
  """Transforms the input from the NCHW_VECT_C layout to NHWC layout.

  Note: Does not include de-quantization or type conversion steps, which should
  be applied beforehand.

  Args:
    input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NHWC

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not 4.
  """
  permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  if is_tensor:
    input_shape = input_shape_or_tensor.shape.as_list()
  else:
    input_shape = input_shape_or_tensor
  if input_shape[-1] != 4:
    raise ValueError("Last dimension of NCHW_VECT_C must be 4.")
  permutation = permutations[len(input_shape)]
  # Fold the trailing vector dimension back into the channel dimension.
  nhwc_shape = [input_shape[axis] for axis in permutation[:-1]]
  nhwc_shape[-1] *= input_shape[-1]
  if not is_tensor:
    return nhwc_shape
  transposed = array_ops.transpose(input_shape_or_tensor, permutation)
  return array_ops.reshape(transposed, nhwc_shape)
419
420
def NCHWToNHWC(input_tensor):
  """Converts the input from the NCHW format to NHWC.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # Maps tensor rank to the permutation moving channels to the last axis.
  permutations = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]}
  if not isinstance(input_tensor, ops.Tensor):
    # Shape-array case: permute the entries of the list.
    return [input_tensor[axis] for axis in permutations[len(input_tensor)]]
  return array_ops.transpose(input_tensor,
                             permutations[input_tensor.shape.ndims])
438
439
def skip_if(condition):
  """Skips the decorated function if condition is or evaluates to True.

  Args:
    condition: Either an expression that can be used in "if not condition"
      statement, or a callable whose result should be a boolean.

  Returns:
    The wrapped function
  """

  def real_skip_if(fn):

    # functools.wraps preserves the wrapped test's name/docstring so test
    # runners report the real test rather than "wrapper".
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      # Evaluate lazily so callable conditions see state at call time.
      skip = condition() if callable(condition) else condition
      if not skip:
        return fn(*args, **kwargs)

    return wrapper

  return real_skip_if
464
465
@contextlib.contextmanager
def skip_if_error(test_obj, error_type, messages=None):
  """Context manager to skip cases not considered failures by the tests.

  Note that this does not work if used in setUpClass/tearDownClass.
  Usage in setUp/tearDown works fine just like regular test methods.

  Args:
    test_obj: A test object provided as `self` in the test methods; this object
      is usually an instance of `unittest.TestCase`'s subclass and should have
      `skipTest` method.
    error_type: The error type to skip. Note that if `messages` are given, both
      `error_type` and `messages` need to match for the test to be skipped.
    messages: Optional, a string or list of strings. If `None`, the test will be
      skipped if `error_type` matches what is raised; otherwise, the test is
      skipped if any of the `messages` is contained in the message of the error
      raised, and `error_type` matches the error raised.

  Yields:
    Nothing.
  """
  if messages:
    messages = nest.flatten(messages)
  try:
    yield
  except error_type as error:
    error_str = str(error)
    # Without messages, any matching error type is skipped; with messages, at
    # least one must appear in the error text.
    matches = not messages or any(m in error_str for m in messages)
    if not matches:
      raise
    test_obj.skipTest("Skipping error: {}: {}".format(type(error), error_str))
496
497
def enable_c_shapes(fn):
  """No-op. TODO(b/74620627): Remove this."""
  # Identity decorator kept only for backwards compatibility with callers.
  return fn
501
502
def with_c_shapes(cls):
  """No-op. TODO(b/74620627): Remove this."""
  # Identity class decorator kept only for backwards compatibility.
  return cls
506
507
def enable_control_flow_v2(fn):
  """Decorator for enabling CondV2 and WhileV2 on a test.

  Note this enables using CondV2 and WhileV2 after running the test class's
  setup/teardown methods.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  # functools.wraps preserves the wrapped test's name/docstring so failures
  # are reported against the real test method, not "wrapper".
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # Save and restore the flag so a failing test cannot leak the setting.
    enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2
    control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
    try:
      return fn(*args, **kwargs)
    finally:
      control_flow_util.ENABLE_CONTROL_FLOW_V2 = enable_control_flow_v2_old

  return wrapper
533
534
def with_control_flow_v2(cls):
  """Adds methods that call original methods with WhileV2 and CondV2 enabled.

  Note this enables CondV2 and WhileV2 in new methods after running the test
  class's setup method.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  If a test function has _disable_control_flow_v2 attr set to True (using the
  @disable_control_flow_v2 decorator), the v2 function is not generated for it.

  Example:

  @test_util.with_control_flow_v2
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    @test_util.disable_control_flow_v2("b/xyzabc")
    def testDisabledForV2(self):
      ...

  Generated class:
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    def testEnabledForV2WithControlFlowV2(self):
      // Enable V2 flags.
      testEnabledForV2(self)
      // Restore V2 flags.

    def testDisabledForV2(self):
      ...

  Args:
    cls: class to decorate

  Returns:
    cls with new test methods added
  """
  # When V2 control flow is already globally enabled, the generated variants
  # would be redundant.
  if control_flow_util.ENABLE_CONTROL_FLOW_V2:
    return cls

  test_prefix = unittest.TestLoader.testMethodPrefix
  # Snapshot the members first: setattr below mutates cls.__dict__.
  for name, value in list(cls.__dict__.items()):
    if not callable(value) or not name.startswith(test_prefix):
      continue
    if getattr(value, "_disable_control_flow_v2", False):
      continue
    setattr(cls, name + "WithControlFlowV2", enable_control_flow_v2(value))
  return cls
588
589
def disable_control_flow_v2(unused_msg):
  """Decorator for a function in a with_control_flow_v2 enabled test class.

  Blocks the function from being run with v2 control flow ops.

  Args:
    unused_msg: Reason for disabling.

  Returns:
    The wrapped function with _disable_control_flow_v2 attr set to True.
  """

  def decorator(func):
    # Mark the test so with_control_flow_v2 skips generating a V2 variant.
    func._disable_control_flow_v2 = True
    return func

  return decorator
607
608
def enable_output_all_intermediates(fn):
  """Force-enable outputting all intermediates from functional control flow ops.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  # functools.wraps preserves the wrapped test's name/docstring for reporting.
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # Save the current override and restore it even if fn raises.
    output_all_intermediates_old = \
        control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
    try:
      return fn(*args, **kwargs)
    finally:
      control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = \
          output_all_intermediates_old

  return wrapper
630
631
def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
  """Decorator for asserting that no new Python objects persist after a test.

  Runs the test multiple times executing eagerly, first as a warmup and then to
  let objects accumulate. The warmup helps ignore caches which do not grow as
  the test is run repeatedly.

  Useful for checking that there are no missing Py_DECREFs in the C exercised by
  a bit of Python.

  Args:
    func: The function to test.
    warmup_iters: The number of warmup iterations, excluded from measuring.

  Returns:
    The wrapped function performing the test.
  """

  def wrap_f(f):
    def decorator(self, *args, **kwargs):
      """Warms up, gets object counts, runs the test, checks for new objects."""
      with context.eager_mode():
        # GC is disabled for the whole measurement so collection timing cannot
        # perturb the object counts; re-enabled at the end.
        gc.disable()
        # Run the test 2 times as warmup, in an attempt to fill up caches, which
        # should not grow as the test is run repeatedly below.
        #
        # TODO(b/117156879): Running warmup twice is black magic; we have seen
        # tests that fail with 1 warmup run, and pass with 2, on various
        # versions of python2.7.x.
        for _ in range(warmup_iters):
          f(self, *args, **kwargs)
        # Since we aren't in the normal test lifecycle, we need to manually run
        # cleanups to clear out their object references.
        self.doCleanups()

        # Some objects are newly created by _get_object_count_by_type().  So
        # create and save as a dummy variable to include it as a baseline.
        obj_count_by_type = _get_object_count_by_type()
        gc.collect()

        # Make sure any registered functions are cleaned up in the C++ runtime.
        registered_function_names = context.context().list_function_names()

        # unittest.doCleanups adds to self._outcome with each unwound call.
        # These objects are retained across gc collections so we exclude them
        # from the object count calculation.
        obj_count_by_type = _get_object_count_by_type(
            exclude=gc.get_referents(self._outcome.errors,
                                     self._outcome.skipped))

        if ops.has_default_graph():
          collection_sizes_before = {
              collection: len(ops.get_collection(collection))
              for collection in ops.get_default_graph().collections
          }
        # Measured iterations: any object growth here is treated as a leak.
        for _ in range(3):
          f(self, *args, **kwargs)
        # Since we aren't in the normal test lifecycle, we need to manually run
        # cleanups to clear out their object references.
        self.doCleanups()
        # Note that gc.get_objects misses anything that isn't subject to garbage
        # collection (C types). Collections are a common source of leaks, so we
        # test for collection sizes explicitly.
        if ops.has_default_graph():
          for collection_key in ops.get_default_graph().collections:
            collection = ops.get_collection(collection_key)
            size_before = collection_sizes_before.get(collection_key, 0)
            if len(collection) > size_before:
              raise AssertionError(
                  ("Collection %s increased in size from "
                   "%d to %d (current items %s).") %
                  (collection_key, size_before, len(collection), collection))
            # Make sure our collection checks don't show up as leaked memory by
            # removing references to temporary variables.
            del collection
            del collection_key
            del size_before
          del collection_sizes_before
        gc.collect()

        # There should be no new Python objects hanging around.
        obj_count_by_type = (
            _get_object_count_by_type(
                exclude=gc.get_referents(self._outcome.errors,
                                         self._outcome.skipped)) -
            obj_count_by_type)

        # There should be no newly registered functions hanging around.
        leftover_functions = (
            context.context().list_function_names() - registered_function_names)
        assert not leftover_functions, (
            "The following functions were newly created: %s" %
            leftover_functions)

        # In some cases (specifically on MacOS), new_count is somehow
        # smaller than previous_count.
        # Using plain assert because not all classes using this decorator
        # have assertLessEqual
        assert not obj_count_by_type, (
            "The following objects were newly created: %s" %
            str(obj_count_by_type))
        gc.enable()
    return decorator

  # Support use both with and without parentheses:
  # @assert_no_new_pyobjects_executing_eagerly and
  # @assert_no_new_pyobjects_executing_eagerly(warmup_iters=...).
  if func is None:
    return wrap_f
  else:
    return wrap_f(func)
740
741
def assert_no_new_tensors(f):
  """Decorator for asserting that no new Tensors persist after a test.

  Mainly useful for checking that code using the Python C API has correctly
  manipulated reference counts.

  Clears the caches that it knows about, runs the garbage collector, then checks
  that there are no Tensor or Tensor-like objects still around. This includes
  Tensors to which something still has a reference (e.g. from missing
  Py_DECREFs) and uncollectable cycles (i.e. Python reference cycles where one
  of the objects has __del__ defined).

  Args:
    f: The test case to run.

  Returns:
    The decorated test case.
  """

  def decorator(self, **kwargs):
    """Finds existing Tensors, runs the test, checks for new Tensors."""

    def _is_tensorflow_object(obj):
      """Returns True if obj is one of the tracked TensorFlow types."""
      try:
        return isinstance(obj,
                          (ops.Tensor, variables.Variable,
                           tensor_shape.Dimension, tensor_shape.TensorShape))
      except (ReferenceError, AttributeError):
        # If the object no longer exists, we don't care about it.
        return False

    # Snapshot IDs (not the objects themselves) so we don't keep them alive.
    tensors_before = set(
        id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
    outside_executed_eagerly = context.executing_eagerly()
    # Run the test in a new graph so that collections get cleared when it's
    # done, but inherit the graph key so optimizers behave.
    outside_graph_key = ops.get_default_graph()._graph_key
    with ops.Graph().as_default():
      ops.get_default_graph()._graph_key = outside_graph_key
      if outside_executed_eagerly:
        # Preserve the caller's execution mode inside the fresh graph.
        with context.eager_mode():
          result = f(self, **kwargs)
      else:
        result = f(self, **kwargs)
    # Make an effort to clear caches, which would otherwise look like leaked
    # Tensors.
    context.context()._clear_caches()  # pylint: disable=protected-access
    gc.collect()
    tensors_after = [
        obj for obj in gc.get_objects()
        if _is_tensorflow_object(obj) and id(obj) not in tensors_before
    ]
    if tensors_after:
      raise AssertionError(("%d Tensors not deallocated after test: %s" % (
          len(tensors_after),
          str(tensors_after),
      )))
    return result

  return decorator
802
803
def _find_reference_cycle(objects, idx):
  """Diagnostic helper: logs one reference cycle reachable from objects[idx].

  Args:
    objects: sequence of candidate objects (e.g. uncollectable garbage).
    idx: index into `objects` of the object whose referrers are explored.

  Returns:
    True if a cycle was found (and logged via logging.error), else False.
  """

  def get_ignore_reason(obj, denylist):
    """Tests whether an object should be omitted from the dependency graph."""
    if len(denylist) > 100:
      return "<depth limit>"
    if tf_inspect.isframe(obj):
      if "test_util.py" in tf_inspect.getframeinfo(obj)[0]:
        return "<test code>"
    for b in denylist:
      if b is obj:
        return "<test code>"
    if obj is denylist:
      return "<test code>"
    return None

  # Note: this function is meant to help with diagnostics. Its output is purely
  # a human-readable representation, so you may freely modify it to suit your
  # needs.
  def describe(obj, denylist, leaves_only=False):
    """Returns a custom human-readable summary of obj.

    Args:
      obj: the value to describe.
      denylist: same as denylist in get_ignore_reason.
      leaves_only: boolean flag used when calling describe recursively. Useful
        for summarizing collections.
    """
    if get_ignore_reason(obj, denylist):
      return "{}{}".format(get_ignore_reason(obj, denylist), type(obj))
    if tf_inspect.isframe(obj):
      return "frame: {}".format(tf_inspect.getframeinfo(obj))
    elif tf_inspect.ismodule(obj):
      return "module: {}".format(obj.__name__)
    else:
      if leaves_only:
        return "{}, {}".format(type(obj), id(obj))
      elif isinstance(obj, list):
        return "list({}): {}".format(
            id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
      elif isinstance(obj, tuple):
        return "tuple({}): {}".format(
            id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
      elif isinstance(obj, dict):
        return "dict({}): {} keys".format(id(obj), len(obj.keys()))
      elif tf_inspect.isfunction(obj):
        return "function({}) {}; globals ID: {}".format(
            id(obj), obj.__name__, id(obj.__globals__))
      else:
        return "{}, {}".format(type(obj), id(obj))

  def build_ref_graph(obj, graph, reprs, denylist):
    """Builds a reference graph as <referrer> -> <list of referents>.

    Args:
      obj: The object to start from. The graph will be built by recursively
        adding its referrers.
      graph: Dict holding the graph to be built. To avoid creating extra
        references, the graph holds object IDs rather than actual objects.
      reprs: Auxiliary structure that maps object IDs to their human-readable
        description.
      denylist: List of objects to ignore.
    """
    referrers = gc.get_referrers(obj)
    # The referrers list itself is a fresh reference; add it to the denylist
    # so it is not traversed as if it were part of the program's graph.
    denylist = denylist + (referrers,)

    obj_id = id(obj)
    for r in referrers:
      if get_ignore_reason(r, denylist) is None:
        r_id = id(r)
        if r_id not in graph:
          graph[r_id] = []
        if obj_id not in graph[r_id]:
          graph[r_id].append(obj_id)
          build_ref_graph(r, graph, reprs, denylist)
          reprs[r_id] = describe(r, denylist)

  def find_cycle(el, graph, reprs, path):
    """Finds and prints a single cycle in the dependency graph."""
    if el not in graph:
      return
    for r in graph[el]:
      if r in path:
        logging.error("Reference cycle sample:")
        for p in path + (r,):
          logging.error(reprs.get(p, "unknown object " + str(p)))
        return True
      else:
        if find_cycle(r, graph, reprs, path + (r,)):
          return True
    return False

  obj = objects[idx]
  graph = {}  # referrer ID -> object ID
  reprs = {}  # object ID -> description
  # Denylist this function's own locals and helpers so the diagnostic
  # machinery does not appear in the reported graph.
  build_ref_graph(obj, graph, reprs, (objects, graph, reprs, get_ignore_reason,
                                      describe, build_ref_graph, find_cycle))
  for k in graph:
    if find_cycle(k, graph, reprs, ()):
      return True
  return False
905
906
def assert_no_garbage_created(f):
  """Test method decorator to assert that no garbage has been created.

  Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters
  cannot be un-set (i.e. will disable garbage collection for any other unit
  tests in the same file/shard).

  Args:
    f: The function to decorate.

  Returns:
    The decorated function.
  """

  def decorator(self, **kwargs):
    """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
    # Force-load `distribution_strategy_context` to prevent GC at
    # test time when using eager. Remove once b/117329403 is resolved.
    tape.distribution_strategy_context.get_strategy()

    gc.disable()
    previous_debug_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)
    try:
      gc.collect()
      previous_garbage = len(gc.garbage)
      result = f(self, **kwargs)
      gc.collect()
      new_garbage = len(gc.garbage)
      if new_garbage > previous_garbage:

        def _safe_object_str(obj):
          # Identify the object without calling str()/repr() on it, which may
          # themselves raise for arbitrary objects.
          return "<%s %d>" % (obj.__class__.__name__, id(obj))

        logging.error(
            "The decorated test created work for Python's garbage collector, "
            "likely due to a reference cycle. New objects in cycle(s):")
        for i, obj in enumerate(gc.garbage[previous_garbage:]):
          try:
            logging.error("Object %d of %d", i,
                          len(gc.garbage) - previous_garbage)
            logging.error("  Object type: %s", _safe_object_str(obj))
            logging.error(
                "  Referrer types: %s", ", ".join(
                    [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
            logging.error(
                "  Referent types: %s", ", ".join(
                    [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
            logging.error("  Object attribute names: %s", dir(obj))
            logging.error("  Object __str__:")
            logging.error(obj)
            logging.error("  Object __repr__:")
            logging.error(repr(obj))
          except Exception:  # pylint: disable=broad-except
            logging.error("(Exception while printing object)")

        # When garbage is created, this call can help identify reference
        # cycles, which are typically the cause of such garbage.
        for i in range(previous_garbage, new_garbage):
          if _find_reference_cycle(gc.garbage, i):
            break

      # This will fail if any garbage has been created, typically because of a
      # reference cycle.
      self.assertEqual(previous_garbage, new_garbage)
    finally:
      # Restore collector state even when the test body or the assertion above
      # raises, so later tests in the same file/shard still run with garbage
      # collection enabled.
      # TODO(allenl): Figure out why this debug flag reset doesn't work. It
      # would be nice to be able to decorate arbitrary tests in a large test
      # suite and not hold on to every object in other tests.
      gc.set_debug(previous_debug_flags)
      gc.enable()
    return result

  return decorator
980
981
982def _combine_named_parameters(**kwargs):
983  """Generate combinations based on its keyword arguments.
984
985  Two sets of returned combinations can be concatenated using +.  Their product
986  can be computed using `times()`.
987
988  Args:
989    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
990      `option=the_only_possibility`.
991
992  Returns:
993    a list of dictionaries for each combination. Keys in the dictionaries are
994    the keyword argument names.  Each key has one value - one of the
995    corresponding keyword argument values.
996  """
997  sort_by_key = lambda k: k[0]
998  combinations = []
999  for key, values in sorted(kwargs.items(), key=sort_by_key):
1000    if not isinstance(values, list):
1001      values = [values]
1002    combinations.append([(key, value) for value in values])
1003
1004  return [OrderedDict(result) for result in itertools.product(*combinations)]
1005
1006
def generate_combinations_with_testcase_name(**kwargs):
  """Generate combinations based on its keyword arguments using combine().

  This function calls combine() and appends a testcase name to the list of
  dictionaries returned. The 'testcase_name' key is a required for named
  parameterized tests.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
      `option=the_only_possibility`.

  Returns:
    a list of dictionaries for each combination. Keys in the dictionaries are
    the keyword argument names.  Each key has one value - one of the
    corresponding keyword argument values.
  """
  named_combinations = []
  for combination in _combine_named_parameters(**kwargs):
    assert isinstance(combination, OrderedDict)
    # Build a name fragment per option, keeping only alphanumeric characters
    # so the result is a valid test method name suffix.
    name_parts = []
    for key, value in combination.items():
      clean_key = "".join(filter(str.isalnum, key))
      clean_value = "".join(filter(str.isalnum, str(value)))
      name_parts.append("_{}_{}".format(clean_key, clean_value))
    entry = OrderedDict(combination)
    entry["testcase_name"] = "_test{}".format("".join(name_parts))
    named_combinations.append(entry)

  return named_combinations
1038
1039
def run_all_in_graph_and_eager_modes(cls):
  """Execute all test methods in the given class with and without eager."""
  base_decorator = run_in_graph_and_eager_modes
  test_prefix = unittest.TestLoader.testMethodPrefix
  for attr_name in dir(cls):
    # Decorate only real test methods; explicit skip-eager tests and the
    # `test_session` helper are left untouched.
    is_test = attr_name.startswith(test_prefix)
    excluded = (attr_name.startswith("testSkipEager") or
                attr_name.startswith("test_skip_eager") or
                attr_name == "test_session")
    if not is_test or excluded:
      continue
    attr = getattr(cls, attr_name, None)
    if callable(attr):
      setattr(cls, attr_name, base_decorator(attr))
  return cls
1053
1054
def build_as_function_and_v1_graph(func=None):
  """Run a test case in v1 graph mode and inside tf.function in eager mode.

  WARNING: This decorator can only be used in test cases that statically checks
  generated graph. Attempting to evaluate graph or function results via.
  session.run() or self.evaluate() will fail.

  WARNING: This decorator can only be used for test cases that inherit from
  absl.testing.parameterized.TestCase.

  Args:
    func: Test case function to be decorated.

  Returns:
    Decorated test case function.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError(
          "`run_in_graph_mode_and_function` only supports test methods.")

    @parameterized.named_parameters(("_v1_graph", "v1_graph"),
                                    ("_function", "function"))
    @functools.wraps(f)
    def decorated(self, run_mode, *args, **kwargs):
      if run_mode == "v1_graph":
        with ops.Graph().as_default():
          f(self, *args, **kwargs)
      elif run_mode == "function":

        @def_function.function
        def function_in_eager():
          f(self, *args, **kwargs)

        # Create a new graph for the eagerly executed version of this test for
        # better isolation.
        graph_for_eager_test = ops.Graph()
        with graph_for_eager_test.as_default(), context.eager_mode():
          function_in_eager()
        ops.dismantle_graph(graph_for_eager_test)
      else:
        # Bug fix: this previously *returned* a ValueError instance instead of
        # raising it, so an unknown run mode passed silently.
        raise ValueError("Unknown run mode %s" % run_mode)

    return decorated

  if func is not None:
    return decorator(func)

  return decorator
1105
1106
def run_in_async_and_sync_mode(f):
  """Execute the test in async mode and sync mode."""

  @parameterized.named_parameters([("Async", True), ("", False)])
  @functools.wraps(f)
  def decorator(self, async_mode, *args, **kwargs):
    # Pick the execution mode once, then run the test under it.
    execution_mode = context.ASYNC if async_mode else context.SYNC
    with context.execution_mode(execution_mode):
      f(self, *args, **kwargs)

  return decorator
1120
1121
def run_in_graph_and_eager_modes(func=None,
                                 config=None,
                                 use_gpu=True,
                                 assert_no_eager_garbage=False):
  """Execute the decorated test with and without enabling eager execution.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will cause the contents of the test
  method to be executed twice - once normally, and once with eager execution
  enabled. This allows unittests to confirm the equivalence between eager
  and graph execution (see `tf.compat.v1.enable_eager_execution`).

  For example, consider the following unittest:

  ```python
  class MyTests(tf.test.TestCase):

    @run_in_graph_and_eager_modes
    def test_foo(self):
      x = tf.constant([1, 2])
      y = tf.constant([3, 4])
      z = tf.add(x, y)
      self.assertAllEqual([4, 6], self.evaluate(z))

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test validates that `tf.add()` has the same behavior when computed with
  eager execution enabled as it does when constructing a TensorFlow graph and
  executing the `z` tensor in a session.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.


  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.
    config: An optional config_pb2.ConfigProto to use to configure the session
      when executing graphs.
    use_gpu: If True, attempt to run as many operations as possible on GPU.
    assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage
      collector and asserts that no extra garbage has been created when running
      the test with eager execution enabled. This will fail if there are
      reference cycles (e.g. a = []; a.append(a)). Off by default because some
      tests may create garbage for legitimate reasons (e.g. they define a class
      which inherits from `object`), and because DEBUG_SAVEALL is sticky in some
      Python interpreters (meaning that tests which rely on objects being
      collected elsewhere in the unit test file will not work). Additionally,
      checks that nothing still has a reference to Tensors that the test
      allocated.

  Returns:
    Returns a decorator that will run the decorated test method twice:
    once by constructing and executing a graph in a session and once with
    eager execution enabled.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError(
          "`run_in_graph_and_eager_modes` only supports test methods. "
          "Did you mean to use `run_all_in_graph_and_eager_modes`?")

    def decorated(self, *args, **kwargs):
      # First pass: run the test in graph mode inside a test session. A skip
      # raised here is swallowed so the eager pass below still executes.
      try:
        with context.graph_mode():
          with self.test_session(use_gpu=use_gpu, config=config):
            f(self, *args, **kwargs)
      except unittest.case.SkipTest:
        pass

      def run_eagerly(self, **kwargs):
        # NOTE: positional `args` come from the enclosing `decorated` closure;
        # only keyword arguments flow through this signature (which must stay
        # (self, **kwargs) so the garbage-assertion decorators below can wrap
        # it like a test method).
        if not use_gpu:
          with ops.device("/device:CPU:0"):
            f(self, *args, **kwargs)
        else:
          f(self, *args, **kwargs)

      if assert_no_eager_garbage:
        ops.reset_default_graph()
        run_eagerly = assert_no_new_tensors(
            assert_no_garbage_created(run_eagerly))

      # This decorator runs the wrapped test twice.
      # Reset the test environment between runs.
      self.tearDown()
      self._tempdir = None
      # Create a new graph for the eagerly executed version of this test for
      # better isolation.
      graph_for_eager_test = ops.Graph()
      with graph_for_eager_test.as_default(), context.eager_mode():
        self.setUp()
        run_eagerly(self, **kwargs)
      ops.dismantle_graph(graph_for_eager_test)

    return tf_decorator.make_decorator(f, decorated)

  if func is not None:
    return decorator(func)

  return decorator
1227
1228
def py_func_if_in_function(f):
  """Runs `f` directly, or via `script_ops.py_func` when building a function."""

  def decorated(*args, **kwds):
    if not ops.inside_function():
      return f(*args, **kwds)

    # Collect the positional arguments that are tensors/variables, remembering
    # where each one came from so it can be put back in place.
    captured = [(position, value)
                for position, value in enumerate(args)
                if isinstance(value, (ops.Tensor, variables.Variable))]
    tensor_indices = [position for position, _ in captured]
    tensor_args = [value for _, value in captured]

    def inner_f(*inner_tensor_args):
      rebuilt_args = list(args)
      for position, tensor in zip(tensor_indices, inner_tensor_args):
        rebuilt_args[position] = tensor
      return f(*rebuilt_args, **kwds)

    return script_ops.py_func(inner_f, tensor_args, [])

  return tf_decorator.make_decorator(f, decorated)
1251
1252
def also_run_as_tf_function(f):
  """Runs the decorated test twice--once as is, once inside a tf.function.

  This allows you to run a test both in eager execution and inside a
  tf.function, exercising the two execution modes supported in tf 2.0. The test
  assertions are automatically done inside tf.py_funcs, and tf.function ensures
  that they run in the proper order and with the proper side effects.

  Currently variable creation is not supported in tests annotated with this
  decorator since it's tricky to ensure the variable doesn't get repeatedly
  created when retracing the tf.function.

  Args:
    f: the test method to be decorated

  Returns:
    The decorated test method, which will run both in eager and inside a
    tf.function.
  """

  def decorated(*args, **kwds):

    def bound_f():
      f(*args, **kwds)

    with context.eager_mode():
      # First pass: plain eager execution.
      bound_f()
      # Second pass: the same body traced as a TF function.
      # TODO(b/121143941): Remove the autograph override.
      traced = def_function.function(bound_f, autograph=False)
      traced()

  return decorated
1286
1287
def deprecated_graph_mode_only(func=None):
  """Execute the decorated test in graph mode.

  This function returns a decorator intended to be applied to tests that are not
  compatible with eager mode. When this decorator is applied, the test body will
  be run in an environment where API calls construct graphs instead of executing
  eagerly.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will run the decorated test method in graph mode.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      # Recursively decorate setUp (if overridden) and every test method.
      setup = f.__dict__.get("setUp")
      if setup is not None:
        setattr(f, "setUp", decorator(setup))

      for name, value in f.__dict__.copy().items():
        if (callable(value) and
            name.startswith(unittest.TestLoader.testMethodPrefix)):
          setattr(f, name, decorator(value))

      return f

    def decorated(self, *args, **kwargs):
      # Already in graph mode: run as-is. Otherwise enter graph mode first.
      if not context.executing_eagerly():
        return f(self, *args, **kwargs)
      with context.graph_mode():
        return f(self, *args, **kwargs)

    return decorated

  return decorator if func is None else decorator(func)
1335
1336
# Alias for `deprecated_graph_mode_only`, exposed under the v1-oriented name.
run_deprecated_v1 = deprecated_graph_mode_only
1338
1339
def run_all_in_deprecated_graph_mode_only(cls):
  """Execute all tests in a class in graph mode."""
  test_prefix = unittest.TestLoader.testMethodPrefix
  for attr_name in dir(cls):
    # Only decorate test methods; leave the `test_session` helper alone.
    if not attr_name.startswith(test_prefix) or attr_name == "test_session":
      continue
    candidate = getattr(cls, attr_name, None)
    if callable(candidate):
      setattr(cls, attr_name, deprecated_graph_mode_only(candidate))
  return cls
1351
1352
def run_v1_only(reason, func=None):
  """Execute the decorated test only if running in v1 mode.

  This function is intended to be applied to tests that exercise v1 only
  functionality. If the test is run in v2 mode it will simply be skipped.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    reason: string giving a reason for limiting the test to v1 only.
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """
  if not isinstance(reason, str):
    raise ValueError("'reason' should be string, got {}".format(type(reason)))

  def decorator(f):
    if not tf_inspect.isclass(f):
      # Plain test method: skip it when running under v2.
      def decorated(self, *args, **kwargs):
        if tf2.enabled():
          self.skipTest(reason)

        return f(self, *args, **kwargs)

      return decorated

    # To skip an entire test suite class, we only decorate the setUp method
    # to skip all tests. There are cases when setUp is not defined (not
    # overridden in subclasses of TestCase, so not available in f.__dict__
    # below). For those cases, we walk the method resolution order list and
    # pick the first setUp method we find (usually this should be the one in
    # the parent class since that's the TestCase class).
    for klass in type.mro(f):
      setup = klass.__dict__.get("setUp")
      if setup is not None:
        setattr(f, "setUp", decorator(setup))
        break

    return f

  if func is not None:
    return decorator(func)

  return decorator
1404
1405
def run_v2_only(func=None):
  """Execute the decorated test only if running in v2 mode.

  This function is intended to be applied to tests that exercise v2 only
  functionality. If the test is run in v1 mode it will simply be skipped.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_v2_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      # skipTest raises, so the call below only runs in v2 mode.
      if not tf2.enabled():
        self.skipTest("Test is only compatible with v2")

      return f(self, *args, **kwargs)

    return decorated

  return decorator if func is None else decorator(func)
1441
1442
def run_gpu_only(func=None):
  """Execute the decorated test only if a GPU is available.

  This function is intended to be applied to tests that require the presence
  of a GPU. If a GPU is absent, it will simply be skipped.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_gpu_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      # skipTest raises, so the call below only runs when a GPU is present.
      if not is_gpu_available():
        self.skipTest("Test requires GPU")

      return f(self, *args, **kwargs)

    return decorated

  return decorator if func is None else decorator(func)
1474
1475
def run_cuda_only(func=None):
  """Execute the decorated test only if a GPU is available.

  This function is intended to be applied to tests that require the presence
  of a CUDA GPU. If a CUDA GPU is absent, it will simply be skipped.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_cuda_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      # skipTest raises, so the call below only runs with a CUDA GPU present.
      if not is_gpu_available(cuda_only=True):
        self.skipTest("Test requires CUDA GPU")

      return f(self, *args, **kwargs)

    return decorated

  return decorator if func is None else decorator(func)
1507
1508
1509def with_forward_compatibility_horizons(*horizons):
1510  """Executes the decorated test with the specified forward-compat horizons.
1511
1512  Args:
1513    *horizons: A list of (year, month, day) tuples.  If the list includes
1514      `None`, then the test will also be run with no forward-compatibility
1515      horizon set.
1516
1517  Returns:
1518    A decorator that will execute the test with the specified horizons.
1519  """
1520  if not horizons:
1521    raise ValueError("Expected at least one horizon.")
1522  for horizon in horizons:
1523    if not ((horizon is None) or
1524            (len(horizon) == 3 and all(isinstance(x, int) for x in horizon))):
1525      raise ValueError("Bad horizon value: %r" % horizon)
1526
1527  def decorator(f):
1528    if tf_inspect.isclass(f):
1529      raise ValueError("`with_forward_compatibility_horizons` only "
1530                       "supports test methods.")
1531    def decorated(self, *args, **kwargs):
1532      for horizon in horizons:
1533        if horizon is None:
1534          f(self, *args, **kwargs)
1535        else:
1536          (year, month, day) = horizon
1537          with forward_compatibility_horizon(year, month, day):
1538            f(self, *args, **kwargs)
1539    return decorated
1540
1541  return decorator
1542
1543
@deprecation.deprecated(None,
                        "Use `tf.config.list_physical_devices('GPU')` instead.")
@tf_export("test.is_gpu_available")
def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
  """Returns whether TensorFlow can access a GPU.

  Warning: if a non-GPU version of the package is installed, the function would
  also return False. Use `tf.test.is_built_with_cuda` to validate if TensorFlow
  was build with CUDA support.

  For example,
  >>> gpu_available = tf.test.is_gpu_available()
  >>> is_cuda_gpu_available = tf.test.is_gpu_available(cuda_only=True)
  >>> is_cuda_gpu_min_3 = tf.test.is_gpu_available(True, (3,0))

  Args:
    cuda_only: limit the search to CUDA GPUs.
    min_cuda_compute_capability: a (major,minor) pair that indicates the minimum
      CUDA compute capability required, or None if no requirement.

  Note that the keyword arg name "cuda_only" is misleading (since routine will
  return true when a GPU device is available irrespective of whether TF was
  built with CUDA support or ROCm support. However no changes here because

  ++ Changing the name "cuda_only" to something more generic would break
     backward compatibility

  ++ Adding an equivalent "rocm_only" would require the implementation check
     the build type. This in turn would require doing the same for CUDA and thus
     potentially break backward compatibility

  ++ Adding a new "cuda_or_rocm_only" would not break backward compatibility,
     but would require most (if not all) callers to update the call to use
     "cuda_or_rocm_only" instead of "cuda_only"

  Returns:
    True if a GPU device of the requested kind is available.
  """

  # This was needed earlier when we had support for SYCL in TensorFlow.
  del cuda_only

  try:
    for local_device in device_lib.list_local_devices():
      if local_device.device_type != "GPU":
        continue
      gpu_info = gpu_util.compute_capability_from_device_desc(local_device)
      capability = gpu_info.compute_capability or (0, 0)
      if (not min_cuda_compute_capability or
          capability >= min_cuda_compute_capability):
        return True
    return False
  except errors_impl.NotFoundError as e:
    # Only swallow the "could not find CUDA" flavor of NotFoundError; anything
    # else is re-raised.
    if not all(x in str(e) for x in ["CUDA", "not find"]):
      raise e
    logging.error(str(e))
    return False
1600
1601
@contextlib.contextmanager
def device(use_gpu):
  """Uses gpu when requested and available."""
  # Fall back to CPU unless a GPU was both requested and detected.
  gpu_selected = use_gpu and is_gpu_available()
  dev_name = "/device:GPU:0" if gpu_selected else "/device:CPU:0"
  with ops.device(dev_name):
    yield
1611
1612
@contextlib.contextmanager
def use_gpu():
  """Uses gpu when requested and available."""
  # Delegates to `device`, which falls back to CPU when no GPU is available.
  with device(use_gpu=True):
    yield
1618
1619
@contextlib.contextmanager
def force_gpu():
  """Force the gpu to be used."""
  gpu_scope = ops.device("/device:GPU:0")
  with gpu_scope:
    yield
1625
1626
@contextlib.contextmanager
def force_cpu():
  """Force the cpu to be used."""
  cpu_scope = ops.device("/device:CPU:0")
  with cpu_scope:
    yield
1632
1633
class CapturedWrites(object):
  """A utility class to load the captured writes made to a stream."""

  def __init__(self, capture_location):
    # Path of the file the captured stream was redirected to.
    self.capture_location = capture_location

  def contents(self):
    """Get the captured writes as a single string."""
    with open(self.capture_location) as tmp_file:
      return tmp_file.read()
1645
1646
class FakeEagerSession(object):
  """Fake session so tests that conditionally use placeholders can use eager.

  There are a number of tests that conditionally use placeholders for shape
  inference. The pattern is demonstrated here:

  ```python
  with self.cached_session() as sess:
    if static_shape:
      y = math_ops.matmul(x, ...)
      feed_dict = {}
    else:
      x_ph = array_ops.placeholder(...)
      y = math_ops.matmul(x_ph, ...)
      feed_dict = {x_ph: x}
    val = sess.run(y, feed_dict=feed_dict)
  ```

  Since the feed_dict is empty when not using placeholders we should be able to
  call self.evaluate(), however this requires rewriting the test case.
  This class should be considered a stop-gap solution to get tests running with
  eager with minimal changes to the actual test.
  """

  def __init__(self, test_case):
    # The TestCase whose `evaluate` implements the actual fetch evaluation.
    self._test_case = test_case

  def run(self, fetches, *args, **kwargs):
    """Evaluate `fetches`.

    Fail if additional args are specified.

    Args:
      fetches: A Tensor or a nested list/tuple of Tensors.
      *args: Positional arguments
      **kwargs: Keyword arguments

    Raises:
      RuntimeError: If args or kwargs are specified.

    Returns:
      Tensors as numpy values.
    """
    # A non-empty feed_dict means the test really needs placeholders, which
    # do not exist under eager execution.
    feed_dict = kwargs.pop("feed_dict", {})
    if feed_dict:
      raise RuntimeError(
          "feed_dict is not supported when eager execution is enabled "
          "(in this case, sess.run(t) is shorthand for t.numpy()")

    # Any remaining session options are likewise unsupported.
    if args or kwargs:
      raise RuntimeError(
          "Optional args are not supported when eager execution is enabled "
          "(in this case, sess.run(t) is shorthand for t.numpy()")

    return self._test_case.evaluate(fetches)
1702
1703
class ErrorLoggingSession(session.Session):
  """Wrapper around a Session that logs errors in run()."""

  def run(self, *args, **kwargs):
    try:
      return super(ErrorLoggingSession, self).run(*args, **kwargs)
    except errors.OutOfRangeError:
      # OutOfRangeError is the normal completion signal for tf.data pipelines;
      # logging it would make the output of tf.data tests hard to read.
      raise
    except Exception as e:  # pylint: disable=broad-except
      logging.error(str(e))
      raise
1717
1718
def disable_cudnn_autotune(func):
  """Disable autotuning during the call to this function.

  Some tests want to base assertions on a graph being isomorphic with a copy.
  To ensure this, this decorator disables autotuning.

  Args:
    func: Function to run with CuDNN autotuning turned off.

  Returns:
    Decorated function.
  """

  def decorator(f):

    def decorated(self, *args, **kwargs):
      original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE")
      os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false"
      original_xla_flags = os.environ.get("XLA_FLAGS")
      new_xla_flags = "--xla_gpu_autotune_level=0"
      if original_xla_flags:
        new_xla_flags = original_xla_flags + " " + new_xla_flags
      os.environ["XLA_FLAGS"] = new_xla_flags

      try:
        return f(self, *args, **kwargs)
      finally:
        # Restore the environment even when the test raises, so a failing test
        # does not leak autotune settings into subsequent tests.
        if original_tf_cudnn_use_autotune is None:
          del os.environ["TF_CUDNN_USE_AUTOTUNE"]
        else:
          os.environ["TF_CUDNN_USE_AUTOTUNE"] = original_tf_cudnn_use_autotune
        if original_xla_flags is None:
          del os.environ["XLA_FLAGS"]
        else:
          os.environ["XLA_FLAGS"] = original_xla_flags

    return decorated

  if func is not None:
    return decorator(func)

  return decorator
1762
1763
1764# The description is just for documentation purposes.
def enable_tf_xla_constant_folding(description):
  """Returns a decorator factory that enables XLA constant folding.

  Args:
    description: A string explaining why constant folding is needed; used for
      documentation purposes only.

  Returns:
    A decorator factory that wraps a test method.

  Raises:
    ValueError: If `description` is not a string.
  """
  if not isinstance(description, str):
    raise ValueError("'description' should be string, got {}".format(
        type(description)))

  def enable_tf_xla_constant_folding_impl(func):
    """Enable constant folding during the call to this function.

    Some tests fail without constant folding.

    Args:
      func: Function to run with constant folding turned on.

    Returns:
      Decorated function.
    """

    def decorator(f):

      def decorated(self, *args, **kwargs):
        original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled()
        pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False)
        try:
          return f(self, *args, **kwargs)
        finally:
          # Restore the previous setting even when the wrapped test raises,
          # so a failing test cannot leak the flag into later tests (the
          # previous implementation skipped restoration on exceptions).
          pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var)

      return decorated

    if func is not None:
      return decorator(func)

    return decorator

  return enable_tf_xla_constant_folding_impl
1800
1801
1802# Updates test function by selectively disabling it.
1803def _disable_test(execute_func):
1804
1805  def disable_test_impl(func):
1806
1807    def decorator(func):
1808
1809      def decorated(self, *args, **kwargs):
1810        if execute_func:
1811          return func(self, *args, **kwargs)
1812
1813      return decorated
1814
1815    if func is not None:
1816      return decorator(func)
1817
1818    return decorator
1819
1820  return disable_test_impl
1821
1822
1823# The description is just for documentation purposes.
def disable_xla(description):  # pylint: disable=unused-argument
  """Execute the test method only if xla is not enabled."""
  return _disable_test(not is_xla_enabled())
1828
1829
1830# The description is just for documentation purposes.
def disable_mlir_bridge(description):  # pylint: disable=unused-argument
  """Execute the test method only if MLIR bridge is not enabled."""
  return _disable_test(not is_mlir_bridge_enabled())
1835
1836
1837# The description is just for documentation purposes.
def disable_tfrt(unused_description):
  """Returns a decorator that skips a test (or drops a class) under TFRT."""

  def disable_tfrt_impl(cls_or_func):
    """Execute the test only if tfrt is not enabled."""
    if tf_inspect.isclass(cls_or_func):
      # Applied to a class: remove the whole class when TFRT is enabled.
      return None if tfrt_utils.enabled() else cls_or_func

    def decorator(func):

      def decorated(self, *args, **kwargs):
        if tfrt_utils.enabled():
          return None
        return func(self, *args, **kwargs)

      return decorated

    if cls_or_func is not None:
      return decorator(cls_or_func)

    return decorator

  return disable_tfrt_impl
1865
1866
1867def for_all_test_methods(decorator, *args, **kwargs):
1868  """Generate class-level decorator from given method-level decorator.
1869
1870  It is expected for the given decorator to take some arguments and return
1871  a method that is then called on the test method to produce a decorated
1872  method.
1873
1874  Args:
1875    decorator: The decorator to apply.
1876    *args: Positional arguments
1877    **kwargs: Keyword arguments
1878  Returns: Function that will decorate a given classes test methods with the
1879    decorator.
1880  """
1881
1882  def all_test_methods_impl(cls):
1883    """Apply decorator to all test methods in class."""
1884    for name in dir(cls):
1885      value = getattr(cls, name)
1886      if callable(value) and name.startswith(
1887          "test") and (name != "test_session"):
1888        setattr(cls, name, decorator(*args, **kwargs)(value))
1889    return cls
1890
1891  return all_test_methods_impl
1892
1893
1894# The description is just for documentation purposes.
def no_xla_auto_jit(description):  # pylint: disable=unused-argument
  """This test is not intended to be run with XLA auto jit enabled."""
  return _disable_test(not is_xla_enabled())
1899
1900
1901# The description is just for documentation purposes.
def xla_allow_fallback(description):  # pylint: disable=unused-argument
  """Returns a decorator that lets XLA tests fall back to classic TF."""

  def xla_allow_fallback_impl(func):
    """Allow fallback to TF even though testing xla."""

    def decorator(func):

      def decorated(self, *args, **kwargs):
        if not is_xla_enabled():
          return func(self, *args, **kwargs)
        # Update the global XLABuildOpsPassFlags to enable lazy compilation,
        # which allows the compiler to fall back to TF classic. Remember the
        # old value so that we can reset it.
        old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True)
        result = func(self, *args, **kwargs)
        pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value)
        return result

      return decorated

    return decorator(func) if func is not None else decorator

  return xla_allow_fallback_impl
1929
1930
1931# The description is just for documentation purposes.
def run_without_tensor_float_32(description):  # pylint: disable=unused-argument
  """Execute test with TensorFloat-32 disabled.

  While almost every real-world deep learning model runs fine with
  TensorFloat-32, many tests use assertAllClose or similar methods.
  TensorFloat-32 matmuls typically will cause such methods to fail with the
  default tolerances.

  Args:
    description: A description used for documentation purposes, describing why
      the test requires TensorFloat-32 to be disabled.

  Returns:
    Decorator which runs a test with TensorFloat-32 disabled.
  """

  def decorator(f):

    @functools.wraps(f)
    def decorated(self, *args, **kwargs):
      allowed = config.tensor_float_32_execution_enabled()
      try:
        config.enable_tensor_float_32_execution(False)
        # Propagate the wrapped function's return value; the previous
        # implementation silently discarded it.
        return f(self, *args, **kwargs)
      finally:
        config.enable_tensor_float_32_execution(allowed)

    return decorated

  return decorator
1962
1963
1964# The description is just for documentation purposes.
def run_all_without_tensor_float_32(description):  # pylint: disable=unused-argument
  """Execute all tests in a class with TensorFloat-32 disabled.

  Class-level counterpart of `run_without_tensor_float_32`: applies that
  decorator to every `test*` method of the decorated class.

  Args:
    description: A description used for documentation purposes, describing why
      the tests require TensorFloat-32 to be disabled.

  Returns:
    Class decorator.
  """
  return for_all_test_methods(run_without_tensor_float_32, description)
1968
1969
def matmul_without_tf32(a, b, *args, **kwargs):
  """Run matmul but cast float32 inputs to float64 if TensorFloat-32 is enabled.

  This effectively runs matmul without TensorFloat-32. It should only be used in
  tests when verifying some other op or functions works correctly, e.g. to test
  `tf.linalg.sqrtm` by matrix multiplying the output of the op by itself. In
  such cases, the matmul itself is not being tested so it's OK to run it with
  higher precision.

  If a matmul itself is being tested, or some other op which uses matmul, use
  `run_without_tensor_float_32` instead.

  This also casts complex64 inputs to complex128, since TensorFloat-32 can also
  be used with complex64

  Args:
    a: First input to tf.linalg.matmul
    b: Second input to tf.linalg.matmul
    *args: Other positional arguments to tf.linalg.matmul
    **kwargs: Other keyword arguments to tf.linalg.matmul

  Returns:
    A tensor with the same type as `a`.
  """
  if config.tensor_float_32_execution_enabled():
    # Pick a higher-precision dtype unaffected by TensorFloat-32, if any.
    if a.dtype == "float32":
      wide_dtype = "float64"
    elif a.dtype == "complex64":
      wide_dtype = "complex128"
    else:
      wide_dtype = None
    if wide_dtype is not None:
      ret = math_ops.matmul(
          math_ops.cast(a, wide_dtype), math_ops.cast(b, wide_dtype), *args,
          **kwargs)
      return math_ops.cast(ret, a.dtype)
  return math_ops.matmul(a, b, *args, **kwargs)
2006
2007
class EagerSessionWarner(object):
  """Stand-in for a Session that rejects any use under eager execution."""

  def __getattr__(self, attr):
    # Any attribute access on this fake session means the test still assumes
    # graph mode; fail loudly with porting instructions.
    message = (
        "Trying to access properties or call methods on the result of "
        "self.session(), self.cached_session(), etc while eager execution "
        "is enabled. If you're porting this test case to TF 2.0, either "
        "adapt the test to work with eager execution or insert a call to "
        "tf.disable_eager_execution() in the main() function of this test "
        "file.")
    raise AttributeError(message)
2018
2019
2020@tf_export("test.TestCase")
2021class TensorFlowTestCase(googletest.TestCase):
2022  """Base class for tests that need to test TensorFlow."""
2023
2024  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
2025    super(TensorFlowTestCase, self).__init__(methodName)
2026    if is_xla_enabled():
2027      pywrap_tf_session.TF_SetXlaAutoJitMode("2")
2028      pywrap_tf_session.TF_SetXlaMinClusterSize(1)
2029      pywrap_tf_session.TF_SetXlaEnableLazyCompilation(False)
2030      pywrap_tf_session.TF_SetTfXlaCpuGlobalJit(True)
2031      # Constant folding secretly runs code on TF:Classic CPU, so we also
2032      # disable it here.
2033      pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True)
2034
2035    # Check if the mlir bridge has been explicitly enabled or disabled. If
2036    # is_mlir_bridge_enabled() returns None, the user did not explictly enable
2037    # or disable the bridge so do not update enable_mlir_bridge.
2038    if is_mlir_bridge_enabled():
2039      context.context().enable_mlir_bridge = True
2040    elif is_mlir_bridge_enabled() is not None:
2041      context.context().enable_mlir_bridge = False
2042
2043    self._threads = []
2044    self._tempdir = None
2045    self._cached_session = None
2046    self._test_start_time = None
2047    # This flag provides the ability to control whether the graph mode gets
2048    # initialized for TF1 or not. Initializing for TF1, which is what was
2049    # happening earlier, was preventing enablement of 'eager mode' in the test.
2050    self._set_default_seed = True
2051
  def setUp(self):
    """Resets global TF state (graph, RNG seeds, summary writer) per test."""
    super(TensorFlowTestCase, self).setUp()
    self._ClearCachedSession()
    # Seed the Python and NumPy RNGs so tests are reproducible across runs.
    random.seed(random_seed.DEFAULT_GRAPH_SEED)
    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
    # Note: The following line is necessary because some test methods may error
    # out from within nested graph contexts (e.g., via assertRaises and
    # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty
    # under certain versions of Python. That would cause
    # ops.reset_default_graph() to throw an exception if the stack were not
    # cleared first.
    ops._default_graph_stack.reset()  # pylint: disable=protected-access
    ops.reset_default_graph()
    if self._set_default_seed:
      random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
    # Reset summary writer in case another test used set_as_default() with their
    # summary writer.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    summary_state.writer = None

    # Avoiding calling setUp() for the poorly named test_session method.
    if self.id().endswith(".test_session"):
      self.skipTest("Not a test.")

    self._test_start_time = time.time()
2077
2078  def tearDown(self):
2079    # If a subclass overrides setUp and doesn't call the parent class's setUp,
2080    # then we may not have set the start time.
2081    if self._test_start_time is not None:
2082      logging.info("time(%s): %ss", self.id(),
2083                   round(time.time() - self._test_start_time, 2))
2084
2085    for thread in self._threads:
2086      thread.check_termination()
2087
2088    self._ClearCachedSession()
2089    super(TensorFlowTestCase, self).tearDown()
2090
2091  def _ClearCachedSession(self):
2092    if self._cached_session is not None:
2093      self._cached_session.close()
2094      self._cached_session = None
2095
2096  def get_temp_dir(self):
2097    """Returns a unique temporary directory for the test to use.
2098
2099    If you call this method multiple times during in a test, it will return the
2100    same folder. However, across different runs the directories will be
2101    different. This will ensure that across different runs tests will not be
2102    able to pollute each others environment.
2103    If you need multiple unique directories within a single test, you should
2104    use tempfile.mkdtemp as follows:
2105      tempfile.mkdtemp(dir=self.get_temp_dir()):
2106
2107    Returns:
2108      string, the path to the unique temporary directory created for this test.
2109    """
2110    if not self._tempdir:
2111      self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
2112    return self._tempdir
2113
2114  @contextlib.contextmanager
2115  def captureWritesToStream(self, stream):
2116    """A context manager that captures the writes to a given stream.
2117
2118    This context manager captures all writes to a given stream inside of a
2119    `CapturedWrites` object. When this context manager is created, it yields
2120    the `CapturedWrites` object. The captured contents can be accessed  by
2121    calling `.contents()` on the `CapturedWrites`.
2122
2123    For this function to work, the stream must have a file descriptor that
2124    can be modified using `os.dup` and `os.dup2`, and the stream must support
2125    a `.flush()` method. The default python sys.stdout and sys.stderr are
2126    examples of this. Note that this does not work in Colab or Jupyter
2127    notebooks, because those use alternate stdout streams.
2128
2129    Example:
2130    ```python
2131    class MyOperatorTest(test_util.TensorFlowTestCase):
2132      def testMyOperator(self):
2133        input = [1.0, 2.0, 3.0, 4.0, 5.0]
2134        with self.captureWritesToStream(sys.stdout) as captured:
2135          result = MyOperator(input).eval()
2136        self.assertStartsWith(captured.contents(), "This was printed.")
2137    ```
2138
2139    Args:
2140      stream: The stream whose writes should be captured. This stream must have
2141        a file descriptor, support writing via using that file descriptor, and
2142        must have a `.flush()` method.
2143
2144    Yields:
2145      A `CapturedWrites` object that contains all writes to the specified stream
2146      made during this context.
2147    """
2148    stream.flush()
2149    fd = stream.fileno()
2150    tmp_file_path = tempfile.mktemp(dir=self.get_temp_dir())
2151    tmp_file = open(tmp_file_path, "w")
2152    orig_fd = os.dup(fd)
2153    os.dup2(tmp_file.fileno(), fd)
2154    try:
2155      yield CapturedWrites(tmp_file_path)
2156    finally:
2157      tmp_file.close()
2158      os.dup2(orig_fd, fd)
2159
  def _AssertProtoEquals(self, a, b, msg=None):
    """Asserts that a and b are the same proto.

    Uses ProtoEq() first, as it returns correct results
    for floating point attributes, and then use assertProtoEqual()
    in case of failure as it provides good error messages.

    Args:
      a: a proto.
      b: another proto.
      msg: Optional message to report on failure.
    """
    # Fast path: ProtoEq compares floats correctly. Only on mismatch do we
    # fall back to assertProtoEqual, which produces readable failure output.
    if not compare.ProtoEq(a, b):
      compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg)
2174
2175  def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None):
2176    """Asserts that message is same as parsed expected_message_ascii.
2177
2178    Creates another prototype of message, reads the ascii message into it and
2179    then compares them using self._AssertProtoEqual().
2180
2181    Args:
2182      expected_message_maybe_ascii: proto message in original or ascii form.
2183      message: the message to validate.
2184      msg: Optional message to report on failure.
2185    """
2186    if isinstance(expected_message_maybe_ascii, type(message)):
2187      expected_message = expected_message_maybe_ascii
2188      self._AssertProtoEquals(expected_message, message, msg=msg)
2189    elif isinstance(expected_message_maybe_ascii, (str, bytes)):
2190      expected_message = type(message)()
2191      text_format.Merge(
2192          expected_message_maybe_ascii,
2193          expected_message,
2194          descriptor_pool=descriptor_pool.Default())
2195      self._AssertProtoEquals(expected_message, message, msg=msg)
2196    else:
2197      assert False, ("Can't compare protos of type %s and %s." %
2198                     (type(expected_message_maybe_ascii), type(message)))
2199
2200  def assertProtoEqualsVersion(
2201      self,
2202      expected,
2203      actual,
2204      producer=versions.GRAPH_DEF_VERSION,
2205      min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER,
2206      msg=None):
2207    expected = "versions { producer: %d min_consumer: %d };\n%s" % (
2208        producer, min_consumer, expected)
2209    self.assertProtoEquals(expected, actual, msg=msg)
2210
2211  def assertStartsWith(self, actual, expected_start, msg=None):
2212    """Assert that actual.startswith(expected_start) is True.
2213
2214    Args:
2215      actual: str
2216      expected_start: str
2217      msg: Optional message to report on failure.
2218    """
2219    if not actual.startswith(expected_start):
2220      fail_msg = "%r does not start with %r" % (actual, expected_start)
2221      fail_msg += " : %r" % (msg) if msg else ""
2222      self.fail(fail_msg)
2223
  def _eval_tensor(self, tensor):
    """Converts a single fetchable (tensor-like or callable) to numpy values."""
    if tensor is None:
      return None
    elif callable(tensor):
      # Callables are invoked and their result evaluated recursively.
      return self._eval_helper(tensor())
    else:
      try:
        # Order matters here: sparse, ragged and IndexedSlices values must be
        # special-cased before the generic map_structure fallback below.
        if sparse_tensor.is_sparse(tensor):
          return sparse_tensor.SparseTensorValue(tensor.indices.numpy(),
                                                 tensor.values.numpy(),
                                                 tensor.dense_shape.numpy())
        elif ragged_tensor.is_ragged(tensor):
          return ragged_tensor_value.RaggedTensorValue(
              self._eval_tensor(tensor.values),
              self._eval_tensor(tensor.row_splits))
        elif isinstance(tensor, ops.IndexedSlices):
          return ops.IndexedSlicesValue(
              values=tensor.values.numpy(),
              indices=tensor.indices.numpy(),
              dense_shape=tensor.dense_shape.numpy())
        # Convert tensors and composite tensors to numpy arrays.
        return nest.map_structure(lambda t: t.numpy(), tensor,
                                  expand_composites=True)
      except AttributeError as e:
        # Anything without .numpy() (i.e. not an eager tensor) ends up here.
        six.raise_from(ValueError("Unsupported type %s." % type(tensor)), e)
2249
2250  def _eval_helper(self, tensors):
2251    if tensors is None:
2252      return None
2253    return nest.map_structure(self._eval_tensor, tensors)
2254
2255  def evaluate(self, tensors):
2256    """Evaluates tensors and returns numpy values.
2257
2258    Args:
2259      tensors: A Tensor or a nested list/tuple of Tensors.
2260
2261    Returns:
2262      tensors numpy values.
2263    """
2264    if context.executing_eagerly():
2265      return self._eval_helper(tensors)
2266    else:
2267      sess = ops.get_default_session()
2268      if sess is None:
2269        with self.test_session() as sess:
2270          return sess.run(tensors)
2271      else:
2272        return sess.run(tensors)
2273
2274  # pylint: disable=g-doc-return-or-yield
2275  @contextlib.contextmanager
2276  def session(self, graph=None, config=None, use_gpu=True, force_gpu=False):
2277    """A context manager for a TensorFlow Session for use in executing tests.
2278
2279    Note that this will set this session and the graph as global defaults.
2280
2281    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
2282    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
2283    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
2284    possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to
2285    the CPU.
2286
2287    Example:
2288
2289    ``` python
2290    class MyOperatorTest(test_util.TensorFlowTestCase):
2291      def testMyOperator(self):
2292        with self.session():
2293          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
2294          result = MyOperator(valid_input).eval()
2295          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
2296          invalid_input = [-1.0, 2.0, 7.0]
2297          with self.assertRaisesOpError("negative input not supported"):
2298            MyOperator(invalid_input).eval()
2299    ```
2300
2301    Args:
2302      graph: Optional graph to use during the returned session.
2303      config: An optional config_pb2.ConfigProto to use to configure the
2304        session.
2305      use_gpu: If True, attempt to run as many ops as possible on GPU.
2306      force_gpu: If True, pin all ops to `/device:GPU:0`.
2307
2308    Yields:
2309      A Session object that should be used as a context manager to surround
2310      the graph building and execution code in a test case.
2311    """
2312    if context.executing_eagerly():
2313      yield EagerSessionWarner()
2314    else:
2315      with self._create_session(graph, config, force_gpu) as sess:
2316        with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu):
2317          yield sess
2318
2319  @contextlib.contextmanager
2320  def cached_session(self,
2321                     graph=None,
2322                     config=None,
2323                     use_gpu=True,
2324                     force_gpu=False):
2325    """Returns a TensorFlow Session for use in executing tests.
2326
2327    This method behaves differently than self.session(): for performance reasons
2328    `cached_session` will by default reuse the same session within the same
2329    test. The session returned by this function will only be closed at the end
2330    of the test (in the TearDown function).
2331
2332    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
2333    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
2334    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
2335    possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to
2336    the CPU.
2337
2338    Example:
2339    ```python
2340    class MyOperatorTest(test_util.TensorFlowTestCase):
2341      def testMyOperator(self):
2342        with self.cached_session() as sess:
2343          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
2344          result = MyOperator(valid_input).eval()
2345          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
2346          invalid_input = [-1.0, 2.0, 7.0]
2347          with self.assertRaisesOpError("negative input not supported"):
2348            MyOperator(invalid_input).eval()
2349    ```
2350
2351    Args:
2352      graph: Optional graph to use during the returned session.
2353      config: An optional config_pb2.ConfigProto to use to configure the
2354        session.
2355      use_gpu: If True, attempt to run as many ops as possible on GPU.
2356      force_gpu: If True, pin all ops to `/device:GPU:0`.
2357
2358    Yields:
2359      A Session object that should be used as a context manager to surround
2360      the graph building and execution code in a test case.
2361    """
2362    if context.executing_eagerly():
2363      yield FakeEagerSession(self)
2364    else:
2365      sess = self._get_cached_session(
2366          graph, config, force_gpu, crash_if_inconsistent_args=True)
2367      with self._constrain_devices_and_set_default(sess, use_gpu,
2368                                                   force_gpu) as cached:
2369        yield cached
2370
2371  @contextlib.contextmanager
2372  @deprecation.deprecated(None, "Use `self.session()` or "
2373                          "`self.cached_session()` instead.")
2374  def test_session(self,
2375                   graph=None,
2376                   config=None,
2377                   use_gpu=True,
2378                   force_gpu=False):
2379    """Use cached_session instead."""
2380    if self.id().endswith(".test_session"):
2381      self.skipTest(
2382          "Tests that have the name \"test_session\" are automatically skipped "
2383          "by TensorFlow test fixture, as the name is reserved for creating "
2384          "sessions within tests. Please rename your test if you have a test "
2385          "with this name.")
2386    if context.executing_eagerly():
2387      yield None
2388    else:
2389      if graph is None:
2390        sess = self._get_cached_session(
2391            graph, config, force_gpu, crash_if_inconsistent_args=False)
2392        with self._constrain_devices_and_set_default(sess, use_gpu,
2393                                                     force_gpu) as cached:
2394          yield cached
2395      else:
2396        with self.session(graph, config, use_gpu, force_gpu) as sess:
2397          yield sess
2398
2399  # pylint: enable=g-doc-return-or-yield
2400
2401  class _CheckedThread(object):
2402    """A wrapper class for Thread that asserts successful completion.
2403
2404    This class should be created using the TensorFlowTestCase.checkedThread()
2405    method.
2406    """
2407
2408    def __init__(self, testcase, target, args=None, kwargs=None):
2409      """Constructs a new instance of _CheckedThread.
2410
2411      Args:
2412        testcase: The TensorFlowTestCase for which this thread is being created.
2413        target: A callable object representing the code to be executed in the
2414          thread.
2415        args: A tuple of positional arguments that will be passed to target.
2416        kwargs: A dictionary of keyword arguments that will be passed to target.
2417      """
2418      self._testcase = testcase
2419      self._target = target
2420      self._args = () if args is None else args
2421      self._kwargs = {} if kwargs is None else kwargs
2422      self._thread = threading.Thread(target=self._protected_run)
2423      self._exception = None
2424
2425      self._is_thread_joined = False
2426
2427    def _protected_run(self):
2428      """Target for the wrapper thread. Sets self._exception on failure."""
2429      try:
2430        self._target(*self._args, **self._kwargs)
2431      except Exception as e:  # pylint: disable=broad-except
2432        self._exception = e
2433
2434    def start(self):
2435      """Starts the thread's activity.
2436
2437      This must be called at most once per _CheckedThread object. It arranges
2438      for the object's target to be invoked in a separate thread of control.
2439      """
2440      self._thread.start()
2441
2442    def join(self):
2443      """Blocks until the thread terminates.
2444
2445      Raises:
2446        self._testcase.failureException: If the thread terminates with due to
2447          an exception.
2448      """
2449      self._is_thread_joined = True
2450      self._thread.join()
2451      if self._exception is not None:
2452        self._testcase.fail("Error in checkedThread: %s" % str(self._exception))
2453
2454    def is_alive(self):
2455      """Returns whether the thread is alive.
2456
2457      This method returns True just before the run() method starts
2458      until just after the run() method terminates.
2459
2460      Returns:
2461        True if the thread is alive, otherwise False.
2462      """
2463      return self._thread.is_alive()
2464
2465    def check_termination(self):
2466      """Returns whether the checked thread was properly used and did terminate.
2467
2468      Every checked thread should be "join"ed after starting, and before the
2469      test tears down. If it is not joined, it is possible the thread will hang
2470      and cause flaky failures in tests.
2471
2472      Raises:
2473        self._testcase.failureException: If check_termination was called before
2474        thread was joined.
2475
2476        RuntimeError: If the thread is not terminated. This means thread was not
2477        joined with the main thread.
2478      """
2479      if self._is_thread_joined:
2480        if self.is_alive():
2481          raise RuntimeError(
2482              "Thread was not joined with main thread, and is still running "
2483              "when the test finished.")
2484      else:
2485        self._testcase.fail("A checked thread was not joined.")
2486
2487  def checkedThread(self, target, args=None, kwargs=None):
2488    """Returns a Thread wrapper that asserts 'target' completes successfully.
2489
2490    This method should be used to create all threads in test cases, as
2491    otherwise there is a risk that a thread will silently fail, and/or
2492    assertions made in the thread will not be respected.
2493
2494    Args:
2495      target: A callable object to be executed in the thread.
2496      args: The argument tuple for the target invocation. Defaults to ().
2497      kwargs: A dictionary of keyword arguments for the target invocation.
2498        Defaults to {}.
2499
2500    Returns:
2501      A wrapper for threading.Thread that supports start() and join() methods.
2502    """
2503    ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs)
2504    self._threads.append(ret)
2505    return ret
2506
2507  # pylint: enable=invalid-name
2508  @py_func_if_in_function
2509  def assertNear(self, f1, f2, err, msg=None):
2510    """Asserts that two floats are near each other.
2511
2512    Checks that |f1 - f2| < err and asserts a test failure
2513    if not.
2514
2515    Args:
2516      f1: A float value.
2517      f2: A float value.
2518      err: A float value.
2519      msg: An optional string message to append to the failure message.
2520    """
2521    # f1 == f2 is needed here as we might have: f1, f2 = inf, inf
2522    self.assertTrue(
2523        f1 == f2 or math.fabs(f1 - f2) <= err, "%f != %f +/- %f%s" %
2524        (f1, f2, err, " (%s)" % msg if msg is not None else ""))
2525
2526  @py_func_if_in_function
2527  def assertArrayNear(self, farray1, farray2, err, msg=None):
2528    """Asserts that two float arrays are near each other.
2529
2530    Checks that for all elements of farray1 and farray2
2531    |f1 - f2| < err.  Asserts a test failure if not.
2532
2533    Args:
2534      farray1: a list of float values.
2535      farray2: a list of float values.
2536      err: a float value.
2537      msg: Optional message to report on failure.
2538    """
2539    self.assertEqual(len(farray1), len(farray2), msg=msg)
2540    for f1, f2 in zip(farray1, farray2):
2541      self.assertNear(float(f1), float(f2), err, msg=msg)
2542
2543  def _NDArrayNear(self, ndarray1, ndarray2, err):
2544    return np.linalg.norm(ndarray1 - ndarray2) < err
2545
2546  @py_func_if_in_function
2547  def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None):
2548    """Asserts that two numpy arrays have near values.
2549
2550    Args:
2551      ndarray1: a numpy ndarray.
2552      ndarray2: a numpy ndarray.
2553      err: a float. The maximum absolute difference allowed.
2554      msg: Optional message to report on failure.
2555    """
2556    self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg)
2557
2558  def _GetNdArray(self, a):
2559    # If a is tensor-like then convert it to ndarray
2560    if tensor_util.is_tf_type(a):
2561      if isinstance(a, ops._EagerTensorBase):
2562        a = a.numpy()
2563      else:
2564        a = self.evaluate(a)
2565    if not isinstance(a, np.ndarray):
2566      return np.array(a)
2567    return a
2568
  def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
    """Asserts that two array-like values have the same shape and close values.

    Args:
      a: The expected value; anything `_GetNdArray` can convert.
      b: The actual value; anything `_GetNdArray` can convert.
      rtol: relative tolerance.
      atol: absolute tolerance.
      msg: Optional message prepended to the failure details.
    """
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    # When the array rank is small, print its contents. Numpy array printing is
    # implemented using inefficient recursion so prints can cause tests to
    # time out.
    if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
      shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
                            "%s.") % (a.shape, b.shape, b)
    else:
      shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
                                                                     b.shape)
    self.assertEqual(a.shape, b.shape, shape_mismatch_msg)

    msgs = [msg]
    # np.allclose does not always work for our custom bfloat16 extension type
    # when type promotions are involved, so we first cast any bfloat16 arrays
    # to float32.
    a_dtype = a.dtype  # Remember the pre-cast dtype for the failure message.
    a = a.astype(np.float32) if a.dtype == dtypes.bfloat16.as_numpy_dtype else a
    b = b.astype(np.float32) if b.dtype == dtypes.bfloat16.as_numpy_dtype else b
    if not np.allclose(a, b, rtol=rtol, atol=atol):
      # Adds more details to np.testing.assert_allclose.
      #
      # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
      # checks whether two arrays are element-wise equal within a
      # tolerance. The relative difference (rtol * abs(b)) and the
      # absolute difference atol are added together to compare against
      # the absolute difference between a and b.  Here, we want to
      # tell user which elements violate such conditions.
      # An element is also flagged when exactly one of the pair is NaN.
      cond = np.logical_or(
          np.abs(a - b) > atol + rtol * np.abs(b),
          np.isnan(a) != np.isnan(b))
      if a.ndim:
        x = a[np.where(cond)]
        y = b[np.where(cond)]
        msgs.append("not close where = {}".format(np.where(cond)))
      else:
        # np.where is broken for scalars
        x, y = a, b
      msgs.append("not close lhs = {}".format(x))
      msgs.append("not close rhs = {}".format(y))
      msgs.append("not close dif = {}".format(np.abs(x - y)))
      msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
      msgs.append("dtype = {}, shape = {}".format(a_dtype, a.shape))
      # TODO(xpan): There seems to be a bug:
      # tensorflow/compiler/tests:binary_ops_test pass with float32
      # nan even though the equal_nan is False by default internally.
      np.testing.assert_allclose(
          a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True)
2619
  def _assertAllCloseRecursive(self,
                               a,
                               b,
                               rtol=1e-6,
                               atol=1e-6,
                               path=None,
                               msg=None):
    """Recursively asserts two nested structures are elementwise close.

    Handles dicts (and namedtuples via `_asdict`), lists/tuples, and anything
    `_GetNdArray` can convert. `path` records the keys/indices from the root
    to the current element so failure messages can point at the offender.

    Args:
      a: The expected structure.
      b: The actual structure.
      rtol: relative tolerance.
      atol: absolute tolerance.
      path: List of keys/indices leading to the current element; mutated in
        place while recursing (entries are pushed and popped).
      msg: Optional message appended to failure messages.

    Raises:
      ValueError: if only one of `a` and `b` is a dict, or a sequence level
        has mismatched lengths.
    """
    path = path or []
    path_str = (("[" + "][".join(str(p) for p in path) + "]") if path else "")
    msg = msg if msg else ""

    # Check if a and/or b are namedtuples.
    if hasattr(a, "_asdict"):
      a = a._asdict()
    if hasattr(b, "_asdict"):
      b = b._asdict()
    a_is_dict = isinstance(a, collections_abc.Mapping)
    if a_is_dict != isinstance(b, collections_abc.Mapping):
      raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" %
                       (path_str, path_str, msg))
    if a_is_dict:
      self.assertItemsEqual(
          a.keys(),
          b.keys(),
          msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" %
          (path_str, a.keys(), path_str, b.keys(), msg))
      for k in a:
        path.append(k)
        self._assertAllCloseRecursive(
            a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg)
        del path[-1]
    elif isinstance(a, (list, tuple)):
      # Try to directly compare a, b as ndarrays; if not work, then traverse
      # through the sequence, which is more expensive.
      try:
        a_as_ndarray = self._GetNdArray(a)
        b_as_ndarray = self._GetNdArray(b)
        self._assertArrayLikeAllClose(
            a_as_ndarray,
            b_as_ndarray,
            rtol=rtol,
            atol=atol,
            msg="Mismatched value: a%s is different from b%s. %s" %
            (path_str, path_str, msg))
      except (ValueError, TypeError) as e:
        # Ragged/heterogeneous sequences cannot be converted wholesale;
        # compare element by element instead.
        if len(a) != len(b):
          raise ValueError(
              "Mismatched length: a%s has %d items, but b%s has %d items. %s" %
              (path_str, len(a), path_str, len(b), msg))
        for idx, (a_ele, b_ele) in enumerate(zip(a, b)):
          path.append(str(idx))
          self._assertAllCloseRecursive(
              a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg)
          del path[-1]
    # a and b are ndarray like objects
    else:
      try:
        self._assertArrayLikeAllClose(
            a,
            b,
            rtol=rtol,
            atol=atol,
            msg=("Mismatched value: a%s is different from b%s. %s" %
                 (path_str, path_str, msg)))
      except TypeError as e:
        # Augment the TypeError with the types involved and where they live.
        msg = ("Error: a%s has %s, but b%s has %s. %s" %
               (path_str, type(a), path_str, type(b), msg))
        e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
        raise
2689
2690  @py_func_if_in_function
2691  def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
2692    """Asserts that two structures of numpy arrays or Tensors, have near values.
2693
2694    `a` and `b` can be arbitrarily nested structures. A layer of a nested
2695    structure can be a `dict`, `namedtuple`, `tuple` or `list`.
2696
2697    Note: the implementation follows
2698    [`numpy.allclose`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html)
2699    (and numpy.testing.assert_allclose). It checks whether two arrays are
2700    element-wise equal within a tolerance. The relative difference
2701    (`rtol * abs(b)`) and the absolute difference `atol` are added together
2702    to compare against the absolute difference between `a` and `b`.
2703
2704    Args:
2705      a: The expected numpy `ndarray`, or anything that can be converted into a
2706        numpy `ndarray` (including Tensor), or any arbitrarily nested of
2707        structure of these.
2708      b: The actual numpy `ndarray`, or anything that can be converted into a
2709        numpy `ndarray` (including Tensor), or any arbitrarily nested of
2710        structure of these.
2711      rtol: relative tolerance.
2712      atol: absolute tolerance.
2713      msg: Optional message to report on failure.
2714
2715    Raises:
2716      ValueError: if only one of `a[p]` and `b[p]` is a dict or
2717          `a[p]` and `b[p]` have different length, where `[p]` denotes a path
2718          to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and
2719          `[p] = [1]['d']`, then `a[p] = (6, 7)`.
2720    """
2721    if ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b):
2722      return self._assertRaggedClose(a, b, rtol, atol, msg)
2723    self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg)
2724
2725  @py_func_if_in_function
2726  def assertAllCloseAccordingToType(self,
2727                                    a,
2728                                    b,
2729                                    rtol=1e-6,
2730                                    atol=1e-6,
2731                                    float_rtol=1e-6,
2732                                    float_atol=1e-6,
2733                                    half_rtol=1e-3,
2734                                    half_atol=1e-3,
2735                                    bfloat16_rtol=1e-2,
2736                                    bfloat16_atol=1e-2,
2737                                    msg=None):
2738    """Like assertAllClose, but also suitable for comparing fp16 arrays.
2739
2740    In particular, the tolerance is reduced to 1e-3 if at least
2741    one of the arguments is of type float16.
2742
2743    Args:
2744      a: the expected numpy ndarray or anything can be converted to one.
2745      b: the actual numpy ndarray or anything can be converted to one.
2746      rtol: relative tolerance.
2747      atol: absolute tolerance.
2748      float_rtol: relative tolerance for float32.
2749      float_atol: absolute tolerance for float32.
2750      half_rtol: relative tolerance for float16.
2751      half_atol: absolute tolerance for float16.
2752      bfloat16_rtol: relative tolerance for bfloat16.
2753      bfloat16_atol: absolute tolerance for bfloat16.
2754      msg: Optional message to report on failure.
2755    """
2756    a = self._GetNdArray(a)
2757    b = self._GetNdArray(b)
2758    # types with lower tol are put later to overwrite previous ones.
2759    if (a.dtype == np.float32 or b.dtype == np.float32 or
2760        a.dtype == np.complex64 or b.dtype == np.complex64):
2761      rtol = max(rtol, float_rtol)
2762      atol = max(atol, float_atol)
2763    if a.dtype == np.float16 or b.dtype == np.float16:
2764      rtol = max(rtol, half_rtol)
2765      atol = max(atol, half_atol)
2766    if (a.dtype == dtypes.bfloat16.as_numpy_dtype or
2767        b.dtype == dtypes.bfloat16.as_numpy_dtype):
2768      rtol = max(rtol, bfloat16_rtol)
2769      atol = max(atol, bfloat16_atol)
2770
2771    self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
2772
2773  @py_func_if_in_function
2774  def assertNotAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
2775    """Assert that two numpy arrays, or Tensors, do not have near values.
2776
2777    Args:
2778      a: The expected numpy `ndarray`, or anything that can be converted into a
2779        numpy `ndarray` (including Tensor), or any arbitrarily nested of
2780        structure of these.
2781      b: The actual numpy `ndarray`, or anything that can be converted into a
2782        numpy `ndarray` (including Tensor), or any arbitrarily nested of
2783        structure of these.
2784      rtol: relative tolerance.
2785      atol: absolute tolerance.
2786      msg: Optional message to report on failure.
2787
2788    Raises:
2789      AssertionError: If `a` and `b` are unexpectedly close at all elements.
2790    """
2791    try:
2792      self.assertAllClose(a, b,  rtol=rtol, atol=atol, msg=msg)
2793    except AssertionError:
2794      return
2795    msg = msg or ""
2796    raise AssertionError("The two values are close at all elements. %s" % msg)
2797
  @py_func_if_in_function
  def assertAllEqual(self, a, b, msg=None):
    """Asserts that two numpy arrays or Tensors have the same values.

    For floating-point dtypes, a NaN in `a` is considered equal to a NaN at
    the same position in `b`.

    Args:
      a: the expected numpy ndarray or anything can be converted to one.
      b: the actual numpy ndarray or anything can be converted to one.
      msg: Optional message to report on failure.
    """
    if (ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b)):
      # Ragged values need structural comparison; delegate.
      return self._assertRaggedEqual(a, b, msg)
    msg = msg if msg else ""
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    # Arbitrary bounds so that we don't print giant tensors.
    if (b.ndim <= 3 or b.size < 500):
      self.assertEqual(
          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
          " Contents: %r. \n%s." % (a.shape, b.shape, b, msg))
    else:
      self.assertEqual(
          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
          " %s" % (a.shape, b.shape, msg))

    same = (a == b)

    # Treat matching NaNs as equal for floating-point dtypes.
    if (a.dtype in [
        np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype
    ]):
      same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
    msgs = [msg]
    if not np.all(same):
      # Adds more details to np.testing.assert_array_equal.
      diff = np.logical_not(same)
      if a.ndim:
        x = a[np.where(diff)]
        y = b[np.where(diff)]
        msgs.append("not equal where = {}".format(np.where(diff)))
      else:
        # np.where is broken for scalars
        x, y = a, b
      msgs.append("not equal lhs = %r" % x)
      msgs.append("not equal rhs = %r" % y)

      # Handle mixed string types as a result of PY2to3 migration. That is, the
      # mixing between bytes (b-prefix strings, PY2 default) and unicodes
      # (u-prefix strings, PY3 default).
      if six.PY3:
        if (a.dtype.kind != b.dtype.kind and
            {a.dtype.kind, b.dtype.kind}.issubset({"U", "S", "O"})):
          a_list = []
          b_list = []
          # OK to flatten `a` and `b` because they are guaranteed to have the
          # same shape.
          for out_list, flat_arr in [(a_list, a.flat), (b_list, b.flat)]:
            for item in flat_arr:
              if isinstance(item, str):
                # Encode unicode items to bytes so both arrays use a common
                # representation before the final comparison.
                out_list.append(item.encode("utf-8"))
              else:
                out_list.append(item)
          a = np.array(a_list)
          b = np.array(b_list)

      np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))
2862
2863  @py_func_if_in_function
2864  def assertNotAllEqual(self, a, b, msg=None):
2865    """Asserts that two numpy arrays or Tensors do not have the same values.
2866
2867    Args:
2868      a: the expected numpy ndarray or anything can be converted to one.
2869      b: the actual numpy ndarray or anything can be converted to one.
2870      msg: Optional message to report on failure.
2871    """
2872    try:
2873      self.assertAllEqual(a, b)
2874    except AssertionError:
2875      return
2876    raise AssertionError("The two values are equal at all elements. %s" % msg)
2877
2878  @py_func_if_in_function
2879  def assertAllGreater(self, a, comparison_target):
2880    """Assert element values are all greater than a target value.
2881
2882    Args:
2883      a: The numpy `ndarray`, or anything that can be converted into a numpy
2884        `ndarray` (including Tensor).
2885      comparison_target: The target value of comparison.
2886    """
2887    a = self._GetNdArray(a)
2888    self.assertGreater(np.min(a), comparison_target)
2889
2890  @py_func_if_in_function
2891  def assertAllLess(self, a, comparison_target):
2892    """Assert element values are all less than a target value.
2893
2894    Args:
2895      a: The numpy `ndarray`, or anything that can be converted into a numpy
2896        `ndarray` (including Tensor).
2897      comparison_target: The target value of comparison.
2898    """
2899    a = self._GetNdArray(a)
2900    self.assertLess(np.max(a), comparison_target)
2901
2902  @py_func_if_in_function
2903  def assertAllGreaterEqual(self, a, comparison_target):
2904    """Assert element values are all greater than or equal to a target value.
2905
2906    Args:
2907      a: The numpy `ndarray`, or anything that can be converted into a numpy
2908        `ndarray` (including Tensor).
2909      comparison_target: The target value of comparison.
2910    """
2911    a = self._GetNdArray(a)
2912    self.assertGreaterEqual(np.min(a), comparison_target)
2913
2914  @py_func_if_in_function
2915  def assertAllLessEqual(self, a, comparison_target):
2916    """Assert element values are all less than or equal to a target value.
2917
2918    Args:
2919      a: The numpy `ndarray`, or anything that can be converted into a numpy
2920        `ndarray` (including Tensor).
2921      comparison_target: The target value of comparison.
2922    """
2923    a = self._GetNdArray(a)
2924    self.assertLessEqual(np.max(a), comparison_target)
2925
2926  def _format_subscripts(self, subscripts, value, limit=10, indent=2):
2927    """Generate a summary of ndarray subscripts as a list of str.
2928
2929    If limit == N, this method will print up to the first N subscripts on
2930    separate
2931    lines. A line of ellipses (...) will be appended at the end if the number of
2932    subscripts exceeds N.
2933
2934    Args:
2935      subscripts: The tensor (np.ndarray) subscripts, of the same format as
2936        np.where()'s return value, i.e., a tuple of arrays with each array
2937        corresponding to a dimension. E.g., (array([1, 1]), array([0, 1])).
2938      value: (np.ndarray) value of the tensor.
2939      limit: (int) The maximum number of indices to print.
2940      indent: (int) Number of characters to indent at the beginning of each
2941        line.
2942
2943    Returns:
2944      (list of str) the multi-line representation of the subscripts and values,
2945        potentially with omission at the end.
2946    """
2947    lines = []
2948    subscripts = np.transpose(subscripts)
2949    prefix = " " * indent
2950    if np.ndim(value) == 0:
2951      return [prefix + "[0] : " + str(value)]
2952    for subscript in itertools.islice(subscripts, limit):
2953      lines.append(prefix + str(subscript) + " : " +
2954                   str(value[tuple(subscript)]))
2955    if len(subscripts) > limit:
2956      lines.append(prefix + "...")
2957    return lines
2958
2959  @py_func_if_in_function
2960  def assertAllInRange(self,
2961                       target,
2962                       lower_bound,
2963                       upper_bound,
2964                       open_lower_bound=False,
2965                       open_upper_bound=False):
2966    """Assert that elements in a Tensor are all in a given range.
2967
2968    Args:
2969      target: The numpy `ndarray`, or anything that can be converted into a
2970        numpy `ndarray` (including Tensor).
2971      lower_bound: lower bound of the range
2972      upper_bound: upper bound of the range
2973      open_lower_bound: (`bool`) whether the lower bound is open (i.e., > rather
2974        than the default >=)
2975      open_upper_bound: (`bool`) whether the upper bound is open (i.e., < rather
2976        than the default <=)
2977
2978    Raises:
2979      AssertionError:
2980        if the value tensor does not have an ordered numeric type (float* or
2981          int*), or
2982        if there are nan values, or
2983        if any of the elements do not fall in the specified range.
2984    """
2985    target = self._GetNdArray(target)
2986    if not (np.issubdtype(target.dtype, np.floating) or
2987            np.issubdtype(target.dtype, np.integer)):
2988      raise AssertionError(
2989          "The value of %s does not have an ordered numeric type, instead it "
2990          "has type: %s" % (target, target.dtype))
2991
2992    nan_subscripts = np.where(np.isnan(target))
2993    if np.size(nan_subscripts):
2994      raise AssertionError(
2995          "%d of the %d element(s) are NaN. "
2996          "Subscripts(s) and value(s) of the NaN element(s):\n" %
2997          (len(nan_subscripts[0]), np.size(target)) +
2998          "\n".join(self._format_subscripts(nan_subscripts, target)))
2999
3000    range_str = (("(" if open_lower_bound else "[") + str(lower_bound) + ", " +
3001                 str(upper_bound) + (")" if open_upper_bound else "]"))
3002
3003    violations = (
3004        np.less_equal(target, lower_bound) if open_lower_bound else np.less(
3005            target, lower_bound))
3006    violations = np.logical_or(
3007        violations,
3008        np.greater_equal(target, upper_bound)
3009        if open_upper_bound else np.greater(target, upper_bound))
3010    violation_subscripts = np.where(violations)
3011    if np.size(violation_subscripts):
3012      raise AssertionError(
3013          "%d of the %d element(s) are outside the range %s. " %
3014          (len(violation_subscripts[0]), np.size(target), range_str) +
3015          "Subscript(s) and value(s) of the offending elements:\n" +
3016          "\n".join(self._format_subscripts(violation_subscripts, target)))
3017
3018  @py_func_if_in_function
3019  def assertAllInSet(self, target, expected_set):
3020    """Assert that elements of a Tensor are all in a given closed set.
3021
3022    Args:
3023      target: The numpy `ndarray`, or anything that can be converted into a
3024        numpy `ndarray` (including Tensor).
3025      expected_set: (`list`, `tuple` or `set`) The closed set that the elements
3026        of the value of `target` are expected to fall into.
3027
3028    Raises:
3029      AssertionError:
3030        if any of the elements do not fall into `expected_set`.
3031    """
3032    target = self._GetNdArray(target)
3033
3034    # Elements in target that are not in expected_set.
3035    diff = np.setdiff1d(target.flatten(), list(expected_set))
3036    if np.size(diff):
3037      raise AssertionError("%d unique element(s) are not in the set %s: %s" %
3038                           (np.size(diff), expected_set, diff))
3039
3040  @py_func_if_in_function
3041  def assertDTypeEqual(self, target, expected_dtype):
3042    """Assert ndarray data type is equal to expected.
3043
3044    Args:
3045      target: The numpy `ndarray`, or anything that can be converted into a
3046        numpy `ndarray` (including Tensor).
3047      expected_dtype: Expected data type.
3048    """
3049    target = self._GetNdArray(target)
3050    if not isinstance(target, list):
3051      arrays = [target]
3052    for arr in arrays:
3053      self.assertEqual(arr.dtype, expected_dtype)
3054
3055  # pylint: disable=g-doc-return-or-yield
  @contextlib.contextmanager
  def assertRaisesWithPredicateMatch(self, exception_type,
                                     expected_err_re_or_predicate):
    """Returns a context manager to enclose code expected to raise an exception.

    If the exception is an OpError, the op stack is also included in the message
    predicate search.

    Args:
      exception_type: The expected type of exception that should be raised.
      expected_err_re_or_predicate: If this is callable, it should be a function
        of one argument that inspects the passed-in exception and returns True
        (success) or False (please fail the test). Otherwise, the error message
        is expected to match this regular expression partially.

    Returns:
      A context manager to surround code that is expected to raise an
      exception.
    """
    if callable(expected_err_re_or_predicate):
      predicate = expected_err_re_or_predicate
    else:

      def predicate(e):
        """Default predicate: regex-search the (op-augmented) error string."""
        err_str = e.message if isinstance(e, errors.OpError) else str(e)
        op = e.op if isinstance(e, errors.OpError) else None
        while op is not None:
          # Walk the chain of original ops so the regex can also match the
          # names of ops that caused this error.
          err_str += "\nCaused by: " + op.name
          op = op._original_op  # pylint: disable=protected-access
        logging.info("Searching within error strings: '%s' within '%s'",
                     expected_err_re_or_predicate, err_str)
        return re.search(expected_err_re_or_predicate, err_str)

    try:
      yield
      # Reaching here means the enclosed block did not raise at all.
      self.fail(exception_type.__name__ + " not raised")
    except Exception as e:  # pylint: disable=broad-except
      # Fail unless the exception is of the expected type AND satisfies the
      # predicate.
      if not isinstance(e, exception_type) or not predicate(e):
        raise AssertionError("Exception of type %s: %s" %
                             (str(type(e)), str(e)))
3096
3097  # pylint: enable=g-doc-return-or-yield
3098
3099  def assertRaisesOpError(self, expected_err_re_or_predicate):
3100    return self.assertRaisesWithPredicateMatch(errors.OpError,
3101                                               expected_err_re_or_predicate)
3102
3103  def assertShapeEqual(self, np_array, tf_tensor, msg=None):
3104    """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape.
3105
3106    Args:
3107      np_array: A Numpy ndarray or Numpy scalar.
3108      tf_tensor: A Tensor.
3109      msg: Optional message to report on failure.
3110
3111    Raises:
3112      TypeError: If the arguments have the wrong type.
3113    """
3114    if not isinstance(np_array, (np.ndarray, np.generic)):
3115      raise TypeError("np_array must be a Numpy ndarray or Numpy scalar")
3116    if not isinstance(tf_tensor, ops.Tensor):
3117      raise TypeError("tf_tensor must be a Tensor")
3118    self.assertAllEqual(
3119        np_array.shape, tf_tensor.get_shape().as_list(), msg=msg)
3120
3121  def assertDeviceEqual(self, device1, device2, msg=None):
3122    """Asserts that the two given devices are the same.
3123
3124    Args:
3125      device1: A string device name or TensorFlow `DeviceSpec` object.
3126      device2: A string device name or TensorFlow `DeviceSpec` object.
3127      msg: Optional message to report on failure.
3128    """
3129    device1 = pydev.canonical_name(device1)
3130    device2 = pydev.canonical_name(device2)
3131    self.assertEqual(
3132        device1, device2,
3133        "Devices %s and %s are not equal. %s" % (device1, device2, msg))
3134
3135  def _GetPyList(self, a):
3136    """Converts `a` to a nested python list."""
3137    if isinstance(a, ragged_tensor.RaggedTensor):
3138      return self.evaluate(a).to_list()
3139    elif isinstance(a, ops.Tensor):
3140      a = self.evaluate(a)
3141      return a.tolist() if isinstance(a, np.ndarray) else a
3142    elif isinstance(a, np.ndarray):
3143      return a.tolist()
3144    elif isinstance(a, ragged_tensor_value.RaggedTensorValue):
3145      return a.to_list()
3146    else:
3147      return np.array(a).tolist()
3148
3149  def _assertRaggedEqual(self, a, b, msg):
3150    """Asserts that two ragged tensors are equal."""
3151    a_list = self._GetPyList(a)
3152    b_list = self._GetPyList(b)
3153    self.assertEqual(a_list, b_list, msg)
3154
3155    if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
3156      a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
3157      b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
3158      self.assertEqual(a_ragged_rank, b_ragged_rank, msg)
3159
3160  def _assertRaggedClose(self, a, b, rtol, atol, msg=None):
3161    a_list = self._GetPyList(a)
3162    b_list = self._GetPyList(b)
3163    self._assertListCloseRecursive(a_list, b_list, rtol, atol, msg)
3164
3165    if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
3166      a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
3167      b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
3168      self.assertEqual(a_ragged_rank, b_ragged_rank, msg)
3169
3170  def _assertListCloseRecursive(self, a, b, rtol, atol, msg, path="value"):
3171    self.assertEqual(type(a), type(b))
3172    if isinstance(a, (list, tuple)):
3173      self.assertLen(a, len(b), "Length differs for %s" % path)
3174      for i in range(len(a)):
3175        self._assertListCloseRecursive(a[i], b[i], rtol, atol, msg,
3176                                       "%s[%s]" % (path, i))
3177    else:
3178      self._assertAllCloseRecursive(a, b, rtol, atol, path, msg)
3179
  # Fix Python 3+ compatibility issues: alias the unittest methods that were
  # renamed in Python 3 so tests written against the old names keep working.
  if not six.PY2:
    # pylint: disable=invalid-name

    # Silence a deprecation warning: assertRaisesRegexp was renamed to
    # assertRaisesRegex in Python 3.
    assertRaisesRegexp = googletest.TestCase.assertRaisesRegex

    # assertItemsEqual is assertCountEqual as of 3.2.
    assertItemsEqual = googletest.TestCase.assertCountEqual

    # pylint: enable=invalid-name
3191
  @contextlib.contextmanager
  def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
    """Set the session and its graph to global default and constrain devices."""
    if context.executing_eagerly():
      # Graph-mode sessions are meaningless under eager execution.
      yield None
    else:
      with sess.graph.as_default(), sess.as_default():
        if force_gpu:
          # Use the name of an actual device if one is detected, or
          # '/device:GPU:0' otherwise
          gpu_name = gpu_device_name()
          if not gpu_name:
            gpu_name = "/device:GPU:0"
          with sess.graph.device(gpu_name):
            yield sess
        elif use_gpu:
          # No device constraint: ops may be placed on any available device.
          yield sess
        else:
          # Pin all ops to the CPU.
          with sess.graph.device("/device:CPU:0"):
            yield sess
3212
  def _create_session(self, graph, config, force_gpu):
    """See session() for details."""

    def prepare_config(config):
      """Returns a config for sessions.

      Args:
        config: An optional config_pb2.ConfigProto to use to configure the
          session.

      Returns:
        A config_pb2.ConfigProto object.
      """
      # TODO(b/114333779): Enforce allow_soft_placement=False when
      # use_gpu=False. Currently many tests rely on the fact that any device
      # will be used even when a specific device is supposed to be used.
      allow_soft_placement = not force_gpu
      if config is None:
        # No caller-supplied config: start from the ambient context's config.
        config = context.context().config
        config.allow_soft_placement = allow_soft_placement
      elif not allow_soft_placement and config.allow_soft_placement:
        # The caller's config allows soft placement but force_gpu forbids it;
        # rebuild from the context config with soft placement disabled.
        config_copy = context.context().config
        config = config_copy
        config.allow_soft_placement = False
      # Don't perform optimizations for tests so we don't inadvertently run
      # gpu ops on cpu
      config.graph_options.optimizer_options.opt_level = -1
      # Disable Grappler constant folding since some tests & benchmarks
      # use constant input and become meaningless after constant folding.
      # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE
      # GRAPPLER TEAM.
      config.graph_options.rewrite_options.constant_folding = (
          rewriter_config_pb2.RewriterConfig.OFF)
      config.graph_options.rewrite_options.pin_to_host_optimization = (
          rewriter_config_pb2.RewriterConfig.OFF)
      return config

    return ErrorLoggingSession(graph=graph, config=prepare_config(config))
3251
  def _get_cached_session(self,
                          graph=None,
                          config=None,
                          force_gpu=False,
                          crash_if_inconsistent_args=True):
    """See cached_session() for documentation."""
    if self._cached_session is None:
      # First call: create the session and remember the arguments it was
      # created with so later calls can be validated against them.
      sess = self._create_session(
          graph=graph, config=config, force_gpu=force_gpu)
      self._cached_session = sess
      self._cached_graph = graph
      self._cached_config = config
      self._cached_force_gpu = force_gpu
      return sess
    else:
      # Subsequent calls must match the original creation arguments
      # (identity comparison), otherwise the cached session would silently
      # ignore the new arguments.
      if crash_if_inconsistent_args and self._cached_graph is not graph:
        raise ValueError("The graph used to get the cached session is "
                         "different than the one that was used to create the "
                         "session. Maybe create a new session with "
                         "self.session()")
      if crash_if_inconsistent_args and self._cached_config is not config:
        raise ValueError("The config used to get the cached session is "
                         "different than the one that was used to create the "
                         "session. Maybe create a new session with "
                         "self.session()")
      if crash_if_inconsistent_args and (self._cached_force_gpu is
                                         not force_gpu):
        raise ValueError(
            "The force_gpu value used to get the cached session is "
            "different than the one that was used to create the "
            "session. Maybe create a new session with "
            "self.session()")
      return self._cached_session
3285
3286
@tf_export("test.create_local_cluster")
def create_local_cluster(num_workers,
                         num_ps,
                         protocol="grpc",
                         worker_config=None,
                         ps_config=None):
  """Create and start local servers and return the associated `Server` objects.

  "PS" stands for "parameter server": a task responsible for storing and
  updating the model's parameters. Other tasks send updates to these parameters
  as they work on optimizing the parameters. This particular division of labor
  between tasks is not required, but is common for distributed training.

  Read more at https://www.tensorflow.org/guide/extend/architecture

  ![components](https://www.tensorflow.org/images/diag1.svg "components")


  Figure illustrates the interaction of these components.
  "/job:worker/task:0" and "/job:ps/task:0" are both tasks with worker services.


  Example:
  ```python
  workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2)

  worker_sessions = [tf.compat.v1.Session(w.target) for w in workers]

  with tf.device("/job:ps/task:0"):
    ...
  with tf.device("/job:ps/task:1"):
    ...
  with tf.device("/job:worker/task:0"):
    ...
  with tf.device("/job:worker/task:1"):
    ...

  worker_sessions[0].run(...)
  ```

  Args:
    num_workers: Number of worker servers to start.
    num_ps: Number of PS servers to start.
    protocol: Communication protocol. Allowed values are documented in the
      documentation of `tf.distribute.Server`.
    worker_config: (optional) `tf.ConfigProto` to initialize workers. Can be
      used to instantiate multiple devices etc.
    ps_config: (optional) `tf.ConfigProto` to initialize PS servers.

  Returns:
    A tuple `(worker_servers, ps_servers)`.  `worker_servers` is a list
    of `num_workers` objects of type `tf.distribute.Server` (all running
    locally);
    and `ps_servers` is a list of `num_ps` objects of similar type.

  Raises:
    ImportError: if portpicker module was not found at load time
  """
  import portpicker  # pylint: disable=g-import-not-at-top

  # Reserve a free localhost port for every task in the cluster.
  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
  ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
  cluster_spec = server_lib.ClusterSpec({
      "worker": ["localhost:%s" % port for port in worker_ports],
      "ps": ["localhost:%s" % port for port in ps_ports]
  })

  def _start_servers(job_name, job_config, count):
    # Start one in-process server per task index for the given job.
    return [
        server_lib.Server(
            cluster_spec,
            job_name=job_name,
            protocol=protocol,
            task_index=task_index,
            config=job_config,
            start=True) for task_index in range(count)
    ]

  workers = _start_servers("worker", worker_config, num_workers)
  ps_servers = _start_servers("ps", ps_config, num_ps)

  return workers, ps_servers
3374
3375
def get_node_def_from_graph(node_name, graph_def):
  """Returns the `NodeDef` instance for given node name in the graph def.

  Only the top-level entries in `graph_def.node` are examined; NodeDefs
  nested inside function definitions are not searched.

  Args:
    node_name: Name of the NodeDef to search for.
    graph_def: An instance of `GraphDef` proto.

  Returns:
    The `NodeDef` whose name field matches `node_name`, or None if there is
    no such node.
  """
  return next(
      (node_def for node_def in graph_def.node if node_def.name == node_name),
      None)
3392
3393
def set_producer_version(graph, producer_version):
  """Sets graph.graph_def_versions.producer to `producer_version`.

  Args:
    graph: A `tf.Graph` whose recorded producer version will be updated.
    producer_version: Integer producer version to record on the graph.
  """
  # The C API doesn't expose altering GraphDefVersions. We can indirectly set
  # it via import_graph_def though.
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.producer = producer_version
  with graph.as_default():
    importer.import_graph_def(graph_def)
  # Bug fix: the original `assert x, y` form only checked that the producer
  # was truthy and used `producer_version` as the assertion *message*; it
  # never compared the two values (and would wrongly fail for version 0).
  assert graph.graph_def_versions.producer == producer_version
3403
3404
3405@contextlib.contextmanager
3406def _fake_gradient_tape_context_manager():
3407  """tf.gradients(...) implemented as tf.GradientTape context manager interface.
3408
3409  This is useful to test tf.gradients() in tests that uses tf.GradientTape().
3410
3411  Yields:
3412    gradient tape instance that's implemented by tf.gradients() underneath.
3413  """
3414  try:
3415    class FakeGradientTape:
3416
3417      def watch(self, x):
3418        pass
3419
3420      def gradient(self, y, x, grad_ys=None):
3421        result = gradients_impl.gradients(y, x, grad_ys)
3422
3423        # Unlike `tape.gradient()`, `tf.gradients()` returns a list for a single
3424        # element. So unpack if needed to match `tape.gradient()` behavior.
3425        if not isinstance(x, (list, tuple)):
3426          assert len(result) == 1
3427          return result[0]
3428
3429        return result
3430
3431    yield FakeGradientTape()
3432  finally:
3433    pass
3434
3435
class AbstractGradientTape:
  """Gradient tape context manager with interchangeable implementations.

  Depending on `use_tape`, the context is backed either by a real
  `tf.GradientTape` or by a fake tape built on `tf.gradients()`, so a single
  test body can exercise both gradient paths without duplication.
  """

  def __init__(self, use_tape, persistent=False):
    # Whether to use the real GradientTape, and whether it is persistent.
    self._use_tape = use_tape
    self._persistent = persistent

  def __enter__(self):
    self._tape_impl = (
        backprop.GradientTape(persistent=self._persistent)
        if self._use_tape else _fake_gradient_tape_context_manager())
    return self._tape_impl.__enter__()

  def __exit__(self, exc_type, exc_val, exc_tb):
    # The implementation's return value is deliberately discarded so that
    # any in-flight exception always propagates out of the `with` block.
    self._tape_impl.__exit__(exc_type, exc_val, exc_tb)
3456
3457
@contextlib.contextmanager
def run_functions_eagerly(run_eagerly):
  """Runs functions eagerly if `run_eagerly` is true.

  WARNING: Setting `run_eagerly` to True in tests running in V1 graph mode
  *WILL NOT* make the tf.function to run eagerly because eager is disabled by
  default in V1. Instead, tf.function will run as a traced graph function.

  On exit, the functions-run-eagerly state is restored to whatever
  `def_function.functions_run_eagerly()` reported on entry.

  Args:
    run_eagerly: Boolean determining whether to run the function eagerly or not.

  Raises:
    ValueError if `run_eagerly` is not a boolean.

  Yields:
    Nothing.
  """
  if not isinstance(run_eagerly, bool):
    raise ValueError(
        "Expected bool for `run_eagerly` but got {}".format(run_eagerly))

  # In V1 graph mode the flag cannot take effect; warn rather than fail.
  is_v1_graph_mode = not context.executing_eagerly()
  if is_v1_graph_mode and run_eagerly:
    logging.warning(
        "Running tf.function eagerly in V1 graph mode is not supported. "
        "tf.function will be run as a traced graph function.")

  initial_state = def_function.functions_run_eagerly()
  def_function.run_functions_eagerly(run_eagerly)
  try:
    yield
  finally:
    # Always restore the global flag, even if the body raised.
    def_function.run_functions_eagerly(initial_state)
3494