• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16# pylint: disable=invalid-name
17"""Test utils for tensorflow."""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23from collections import OrderedDict
24import contextlib
25import functools
26import gc
27import itertools
28import math
29import os
30import random
31import re
32import tempfile
33import threading
34import unittest
35
36from absl.testing import parameterized
37import numpy as np
38import six
39
40from google.protobuf import descriptor_pool
41from google.protobuf import text_format
42
43from tensorflow.core.framework import graph_pb2
44from tensorflow.core.protobuf import rewriter_config_pb2
45from tensorflow.python import _pywrap_stacktrace_handler
46from tensorflow.python import _pywrap_util_port
47from tensorflow.python import tf2
48from tensorflow.python.client import device_lib
49from tensorflow.python.client import pywrap_tf_session
50from tensorflow.python.client import session
51from tensorflow.python.compat.compat import forward_compatibility_horizon
52from tensorflow.python.eager import context
53from tensorflow.python.eager import def_function
54from tensorflow.python.eager import tape
55from tensorflow.python.framework import device as pydev
56from tensorflow.python.framework import dtypes
57from tensorflow.python.framework import errors
58from tensorflow.python.framework import errors_impl
59from tensorflow.python.framework import gpu_util
60from tensorflow.python.framework import importer
61from tensorflow.python.framework import ops
62from tensorflow.python.framework import random_seed
63from tensorflow.python.framework import sparse_tensor
64from tensorflow.python.framework import tensor_shape
65from tensorflow.python.framework import tensor_util
66from tensorflow.python.framework import versions
67from tensorflow.python.ops import array_ops
68from tensorflow.python.ops import control_flow_util
69from tensorflow.python.ops import control_flow_util_v2
70from tensorflow.python.ops.ragged import ragged_tensor
71from tensorflow.python.ops.ragged import ragged_tensor_value
72from tensorflow.python.ops import script_ops
73from tensorflow.python.ops import summary_ops_v2
74from tensorflow.python.ops import variables
75from tensorflow.python.platform import googletest
76from tensorflow.python.platform import tf_logging as logging
77from tensorflow.python.training import server_lib
78from tensorflow.python.util import compat
79from tensorflow.python.util import deprecation
80from tensorflow.python.util import nest
81from tensorflow.python.util import tf_decorator
82from tensorflow.python.util import tf_inspect
83from tensorflow.python.util.protobuf import compare
84from tensorflow.python.util.tf_export import tf_export
85from tensorflow.python.util.compat import collections_abc
86
87
88# If the below import is made available through the BUILD rule, then this
89# function is overridden and will instead return True and cause Tensorflow
90# graphs to be compiled with XLA.
def is_xla_enabled():
  """Returns False, the default used when no XLA override has been imported."""
  return False
93
94
try:
  from tensorflow.python.framework.is_xla_test_true import is_xla_enabled  # pylint: disable=g-import-not-at-top
except Exception:  # pylint: disable=broad-except
  # The override module is optional: when it cannot be imported the default
  # `is_xla_enabled` defined above stays in effect. `Exception` (rather than a
  # bare `except`) keeps this best-effort while still letting
  # KeyboardInterrupt and SystemExit propagate.
  pass
99
def _get_object_count_by_type():
  """Counts the objects currently tracked by the garbage collector, by type.

  Returns:
    A `collections.Counter` mapping each type name (`type(obj).__name__`) to
    the number of objects of that type returned by `gc.get_objects()`.
  """
  type_names = [type(tracked).__name__ for tracked in gc.get_objects()]
  return collections.Counter(type_names)
102
@tf_export("test.gpu_device_name")
def gpu_device_name():
  """Returns the name of the first local GPU or SYCL device, if any.

  Returns:
    The device name converted with `compat.as_str`, or the empty string when
    `device_lib.list_local_devices()` reports no "GPU" or "SYCL" device.
  """
  accelerator_types = ("GPU", "SYCL")
  matching_names = (
      compat.as_str(device.name)
      for device in device_lib.list_local_devices()
      if device.device_type in accelerator_types)
  return next(matching_names, "")
110
111
def assert_ops_in_graph(expected_ops, graph):
  """Checks that every expected node exists in `graph` with the expected op.

  Args:
    expected_ops: `dict<string, string>` mapping node name to op type.
    graph: Graph to check; its `as_graph_def()` nodes are inspected.

  Returns:
    `dict<string, node>` mapping each expected node name to the matching node.

  Raises:
    ValueError: If an expected node has a different op type, or if some
      expected node names are missing from the graph.
  """
  found = {}
  for node in graph.as_graph_def().node:
    name = node.name
    if name not in expected_ops:
      continue
    wanted_type = expected_ops[name]
    if wanted_type != node.op:
      raise ValueError("Expected op for node %s is different. %s vs %s" %
                       (name, wanted_type, node.op))
    found[name] = node

  all_present = set(expected_ops.keys()) == set(found.keys())
  if not all_present:
    raise ValueError("Not all expected ops are present. Expected %s, found %s" %
                     (expected_ops.keys(), found.keys()))
  return found
137
138
@tf_export("test.assert_equal_graph_def", v1=[])
def assert_equal_graph_def_v2(expected, actual):
  """Asserts that two `GraphDef`s are (mostly) the same.

  The two `GraphDef` protos are compared for equality while ignoring versions
  and the ordering of nodes, attrs, and control inputs. Nodes are matched up by
  name, so node naming must be consistent between the graphs. Randomized
  attribute values that may appear in V2 checkpoints are ignored, as are
  randomized `HashTableV2` shared names.

  Note the argument order: `expected` comes first here, whereas the underlying
  `assert_equal_graph_def` takes `actual` first.

  Args:
    expected: The `GraphDef` we expected.
    actual: The `GraphDef` we have.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  assert_equal_graph_def(
      actual=actual,
      expected=expected,
      checkpoint_v2=True,
      hash_table_shared_name=True)
158
159
@tf_export(v1=["test.assert_equal_graph_def"])
def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False,
                              hash_table_shared_name=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  The two `GraphDef` protos are compared for equality while ignoring versions
  and the ordering of nodes, attrs, and control inputs. Nodes are matched up by
  name, so node naming must be consistent between the graphs.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean; when True, randomized attribute values that appear
      in V2 checkpoints are ignored.
    hash_table_shared_name: boolean; when True, randomized shared_names that
      appear in HashTableV2 op defs are ignored.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  assert_equal_graph_def(
      actual,
      expected,
      checkpoint_v2=checkpoint_v2,
      hash_table_shared_name=hash_table_shared_name)
183
184
def assert_equal_graph_def(actual, expected, checkpoint_v2=False,
                           hash_table_shared_name=False):
  """Raises unless `actual` and `expected` are equivalent `GraphDef` protos.

  When either stripping option is enabled, the matching attrs are deleted from
  both protos in place before they are compared, so the arguments may be
  modified by this call.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean; when True, attrs holding randomized V2 checkpoint
      values are stripped from both protos first.
    hash_table_shared_name: boolean; when True, randomized `HashTableV2`
      shared_name attrs are stripped from both protos first.

  Raises:
    AssertionError: If the comparison reports a difference; the message is the
      reported diff.
    TypeError: If either argument is not a `GraphDef`.
  """
  graph_def_type = graph_pb2.GraphDef
  if not isinstance(actual, graph_def_type):
    actual_type_name = type(actual).__name__
    raise TypeError("Expected tf.GraphDef for actual, got %s" %
                    actual_type_name)
  if not isinstance(expected, graph_def_type):
    expected_type_name = type(expected).__name__
    raise TypeError("Expected tf.GraphDef for expected, got %s" %
                    expected_type_name)

  # Collect the enabled in-place clean-ups, preserving their original order.
  strippers = []
  if checkpoint_v2:
    strippers.append(_strip_checkpoint_v2_randomized)
  if hash_table_shared_name:
    strippers.append(_strip_hash_table_shared_name)
  for strip in strippers:
    strip(actual)
    strip(expected)

  mismatch = pywrap_tf_session.EqualGraphDefWrapper(
      actual.SerializeToString(), expected.SerializeToString())
  if mismatch:
    raise AssertionError(compat.as_str(mismatch))
206
207
def assert_meta_graph_protos_equal(tester, a, b):
  """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

  The collection defs and graph defs are compared field by field and then
  cleared from both protos before the remaining fields are compared, so `a`
  and `b` are modified in place by this call.

  Args:
    tester: Test case object providing `assertEqual` and `assertProtoEquals`.
    a: First `MetaGraphDef` proto; its `collection_def` and `graph_def` fields
      are cleared.
    b: Second `MetaGraphDef` proto; its `collection_def` and `graph_def` fields
      are cleared.

  Raises:
    AssertionError: If the two protos differ (raised by `tester` or by
      `assert_equal_graph_def`).
  """
  # Carefully check the collection_defs
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  collection_keys = a.collection_def.keys()
  for k in collection_keys:
    a_value = a.collection_def[k]
    b_value = b.collection_def[k]
    proto_type = ops.get_collection_proto_type(k)
    if proto_type:
      a_proto = proto_type()
      b_proto = proto_type()
      # Number of entries in the collections is the same
      tester.assertEqual(
          len(a_value.bytes_list.value), len(b_value.bytes_list.value))
      for (a_value_item, b_value_item) in zip(a_value.bytes_list.value,
                                              b_value.bytes_list.value):
        a_proto.ParseFromString(a_value_item)
        b_proto.ParseFromString(b_value_item)
        tester.assertProtoEquals(a_proto, b_proto)
    else:
      # `assertEquals` is a deprecated alias that newer unittest versions no
      # longer provide; `assertEqual` is the supported spelling.
      tester.assertEqual(a_value, b_value)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("collection_def")
  b.ClearField("collection_def")

  # Check the graph_defs.
  assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
  # Check graph_def versions (ignored by assert_equal_graph_def).
  tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("graph_def")
  b.ClearField("graph_def")

  tester.assertProtoEquals(a, b)
245
246
# Matches attributes named via _SHARDED_SUFFIX in
# tensorflow/python/training/saver.py
_SHARDED_SAVE_OP_PATTERN = r"_temp_[0-9a-z]{32}/part"


def _strip_checkpoint_v2_randomized(graph_def):
  """Deletes node attrs that hold a randomized sharded-save temp path.

  An attr is removed when its tensor holds exactly one string value and that
  value starts with a match for `_SHARDED_SAVE_OP_PATTERN`. The graph is
  modified in place.

  Args:
    graph_def: Proto-like object whose `node` entries expose an `attr` mapping.
  """
  for node in graph_def.node:
    delete_keys = []
    for attr_key in node.attr:
      attr_tensor_value = node.attr[attr_key].tensor
      if not attr_tensor_value or len(attr_tensor_value.string_val) != 1:
        continue
      raw_value = attr_tensor_value.string_val[0]
      if not raw_value:
        continue
      # Proto string values are bytes on Python 3; `str(bytes)` would yield
      # "b'...'" and never match the anchored pattern, so decode instead.
      if isinstance(raw_value, bytes):
        text_value = raw_value.decode("utf-8", "replace")
      else:
        text_value = str(raw_value)
      if re.match(_SHARDED_SAVE_OP_PATTERN, text_value):
        delete_keys.append(attr_key)
    # Delete after iterating so the mapping is not mutated while being walked.
    for attr_key in delete_keys:
      del node.attr[attr_key]
264
265
_TABLE_SHARED_NAME_PATTERN = r"hash_table_[0-9a-z\-]+"


def _strip_hash_table_shared_name(graph_def):
  """Deletes randomized `shared_name` attrs from `HashTableV2` nodes.

  The attr is removed when its string value starts with a match for
  `_TABLE_SHARED_NAME_PATTERN`. The graph is modified in place.

  Args:
    graph_def: Proto-like object whose `node` entries expose `op` and an `attr`
      mapping.
  """
  for node in graph_def.node:
    if node.op != "HashTableV2" or "shared_name" not in node.attr:
      continue
    shared_name = node.attr["shared_name"].s
    # Proto string fields are bytes on Python 3; `str(bytes)` would yield
    # "b'...'" and never match the anchored pattern, so decode instead.
    if isinstance(shared_name, bytes):
      shared_name = shared_name.decode("utf-8", "replace")
    else:
      shared_name = str(shared_name)
    if re.match(_TABLE_SHARED_NAME_PATTERN, shared_name):
      del node.attr["shared_name"]
277
278
def IsGoogleCudaEnabled():
  """Returns what `_pywrap_util_port.IsGoogleCudaEnabled()` reports.

  NOTE(review): presumably True when this build has CUDA enabled -- confirm
  against the native implementation.
  """
  cuda_enabled = _pywrap_util_port.IsGoogleCudaEnabled()
  return cuda_enabled
281
282
def IsBuiltWithROCm():
  """Returns what `_pywrap_util_port.IsBuiltWithROCm()` reports.

  NOTE(review): presumably True when this build targets ROCm -- confirm
  against the native implementation.
  """
  built_with_rocm = _pywrap_util_port.IsBuiltWithROCm()
  return built_with_rocm
285
286
def IsBuiltWithXLA():
  """Returns what `_pywrap_util_port.IsBuiltWithXLA()` reports.

  NOTE(review): presumably True when this build includes XLA -- confirm
  against the native implementation.
  """
  built_with_xla = _pywrap_util_port.IsBuiltWithXLA()
  return built_with_xla
289
290
def IsBuiltWithNvcc():
  """Returns what `_pywrap_util_port.IsBuiltWithNvcc()` reports.

  NOTE(review): presumably True when this build was compiled with nvcc --
  confirm against the native implementation.
  """
  built_with_nvcc = _pywrap_util_port.IsBuiltWithNvcc()
  return built_with_nvcc
293
294
def GpuSupportsHalfMatMulAndConv():
  """Returns what `_pywrap_util_port.GpuSupportsHalfMatMulAndConv()` reports.

  NOTE(review): presumably indicates whether half-precision matmul and
  convolution are supported on the GPU -- confirm against the native
  implementation.
  """
  half_supported = _pywrap_util_port.GpuSupportsHalfMatMulAndConv()
  return half_supported
297
298
def IsMklEnabled():
  """Returns what `_pywrap_util_port.IsMklEnabled()` reports.

  NOTE(review): presumably True when this build has MKL enabled -- confirm
  against the native implementation.
  """
  mkl_enabled = _pywrap_util_port.IsMklEnabled()
  return mkl_enabled
301
302
def InstallStackTraceHandler():
  """Calls `_pywrap_stacktrace_handler.InstallStacktraceHandler()`.

  Returns nothing; any effect comes from the native handler installation.
  """
  install_handler = _pywrap_stacktrace_handler.InstallStacktraceHandler
  install_handler()
305
306
def NHWCToNCHW(input_tensor):
  """Converts the input from the NHWC format to NCHW.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    The transposed tensor when given a tensor; otherwise a new list holding
    the shape entries reordered into NCHW.
  """
  # Rank -> axis order that moves the last (channel) axis to position 1.
  axes_by_rank = {4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
  if not isinstance(input_tensor, ops.Tensor):
    axis_order = axes_by_rank[len(input_tensor)]
    return [input_tensor[axis] for axis in axis_order]
  axis_order = axes_by_rank[input_tensor.shape.ndims]
  return array_ops.transpose(input_tensor, axis_order)
324
325
def NHWCToNCHW_VECT_C(input_shape_or_tensor):
  """Transforms the input from the NHWC layout to NCHW_VECT_C layout.

  Note: Does not include quantization or type conversion steps, which should
  be applied afterwards.

  The last dimension is split into `(C // 4, 4)`, so a rank-4 input yields a
  rank-5 result and a rank-5 input yields a rank-6 result. A shape passed as a
  list or tuple is never modified; a new list is returned.

  Args:
    input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NCHW_VECT_C

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not evenly
        divisible by 4.
  """
  # Keyed by the rank AFTER the channel dimension has been split in two.
  permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  # Work on a private copy: the in-place edits below must not leak into the
  # caller's shape list (and a copy also lets tuples be accepted).
  temp_shape = list(
      input_shape_or_tensor.shape.as_list()
      if is_tensor else input_shape_or_tensor)
  if temp_shape[-1] % 4 != 0:
    raise ValueError(
        "Last dimension of input must be evenly divisible by 4 to convert to "
        "NCHW_VECT_C.")
  temp_shape[-1] //= 4
  temp_shape.append(4)
  permutation = permutations[len(temp_shape)]
  if is_tensor:
    t = array_ops.reshape(input_shape_or_tensor, temp_shape)
    return array_ops.transpose(t, permutation)
  else:
    return [temp_shape[a] for a in permutation]
359
360
def NCHW_VECT_CToNHWC(input_shape_or_tensor):
  """Transforms the input from the NCHW_VECT_C layout to NHWC layout.

  Note: Does not include de-quantization or type conversion steps, which should
  be applied beforehand.

  The trailing vector dimension of size 4 is folded back into the channel
  dimension, so the result has one dimension fewer than the input.

  Args:
    input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NHWC

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not 4.
  """
  axis_orders = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  if is_tensor:
    source_shape = input_shape_or_tensor.shape.as_list()
  else:
    source_shape = input_shape_or_tensor
  vect_size = source_shape[-1]
  if vect_size != 4:
    raise ValueError("Last dimension of NCHW_VECT_C must be 4.")
  axis_order = axis_orders[len(source_shape)]
  # Drop the vector axis from the shape and fold its size into the channels.
  target_shape = [source_shape[axis] for axis in axis_order[:-1]]
  target_shape[-1] *= vect_size
  if not is_tensor:
    return target_shape
  transposed = array_ops.transpose(input_shape_or_tensor, axis_order)
  return array_ops.reshape(transposed, target_shape)
391
392
def NCHWToNHWC(input_tensor):
  """Converts the input from the NCHW format to NHWC.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    The transposed tensor when given a tensor; otherwise a new list holding
    the shape entries reordered into NHWC.
  """
  # Rank -> axis order that moves the channel axis from position 1 to the end.
  axes_by_rank = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]}
  if not isinstance(input_tensor, ops.Tensor):
    axis_order = axes_by_rank[len(input_tensor)]
    return [input_tensor[axis] for axis in axis_order]
  axis_order = axes_by_rank[input_tensor.shape.ndims]
  return array_ops.transpose(input_tensor, axis_order)
410
411
def skip_if(condition):
  """Skips the decorated function if condition is or evaluates to True.

  The condition is evaluated each time the decorated function is called, not
  when the decorator is applied. A skipped call returns None.

  Args:
    condition: Either an expression that can be used in "if not condition"
      statement, or a callable whose result should be a boolean.

  Returns:
    The wrapped function
  """

  def real_skip_if(fn):

    # Preserve the wrapped function's name and docstring so test reports and
    # introspection still identify the original test.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      skip = condition() if callable(condition) else condition
      if not skip:
        return fn(*args, **kwargs)
      return None

    return wrapper

  return real_skip_if
436
437
def enable_c_shapes(fn):
  """No-op decorator that returns `fn` unchanged.

  TODO(b/74620627): Remove this.
  """
  return fn
441
442
def with_c_shapes(cls):
  """No-op class decorator that returns `cls` unchanged.

  TODO(b/74620627): Remove this.
  """
  return cls
446
447
def enable_control_flow_v2(fn):
  """Decorator for enabling CondV2 and WhileV2 on a test.

  Note this enables using CondV2 and WhileV2 after running the test class's
  setup/teardown methods.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  The previous value of `control_flow_util.ENABLE_CONTROL_FLOW_V2` is restored
  after each call, even if the wrapped function raises.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  # Preserve the wrapped function's name and docstring for test reports and
  # introspection.
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2
    control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
    try:
      return fn(*args, **kwargs)
    finally:
      control_flow_util.ENABLE_CONTROL_FLOW_V2 = enable_control_flow_v2_old

  return wrapper
473
474
def with_control_flow_v2(cls):
  """Adds test-method variants that run with WhileV2 and CondV2 enabled.

  For every callable attribute defined directly on `cls` whose name starts
  with the unittest test-method prefix, a sibling named
  `<name>WithControlFlowV2` is added that runs the original through
  `enable_control_flow_v2`. The V2 flags are therefore switched on only after
  the test class's setup method has run.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  Test functions whose `_disable_control_flow_v2` attr is True (set with the
  @disable_control_flow_v2 decorator) get no V2 variant. If control flow V2 is
  already enabled globally, the class is returned untouched.

  Example:

  @test_util.with_control_flow_v2
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    @test_util.disable_control_flow_v2("b/xyzabc")
    def testDisabledForV2(self):
      ...

  Generated class:
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    def testEnabledForV2WithControlFlowV2(self):
      // Enable V2 flags.
      testEnabledForV2(self)
      // Restore V2 flags.

    def testDisabledForV2(self):
      ...

  Args:
    cls: class to decorate

  Returns:
    cls with new test methods added
  """
  if control_flow_util.ENABLE_CONTROL_FLOW_V2:
    return cls

  test_prefix = unittest.TestLoader.testMethodPrefix
  # Iterate over a snapshot because new attributes are added inside the loop.
  for attr_name, member in list(cls.__dict__.items()):
    if not callable(member):
      continue
    if not attr_name.startswith(test_prefix):
      continue
    if getattr(member, "_disable_control_flow_v2", False):
      continue
    v2_variant = enable_control_flow_v2(member)
    setattr(cls, attr_name + "WithControlFlowV2", v2_variant)
  return cls
528
529
def disable_control_flow_v2(unused_msg):
  """Marks a test in a with_control_flow_v2 class as excluded from V2 runs.

  The decorated function is tagged so that `with_control_flow_v2` does not
  generate a control-flow-V2 variant for it.

  Args:
    unused_msg: Reason for disabling.

  Returns:
    A decorator that sets `_disable_control_flow_v2 = True` on the function it
    receives and returns that same function.
  """

  def mark_disabled(func):
    setattr(func, "_disable_control_flow_v2", True)
    return func

  return mark_disabled
547
548
def enable_output_all_intermediates(fn):
  """Force-enable outputting all intermediates from functional control flow ops.

  While the wrapped function runs,
  `control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE` is set
  to True; its previous value is restored afterwards, even if the wrapped
  function raises.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  # Preserve the wrapped function's name and docstring for test reports and
  # introspection.
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    output_all_intermediates_old = \
        control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
    control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
    try:
      return fn(*args, **kwargs)
    finally:
      control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = \
          output_all_intermediates_old

  return wrapper
570
571
def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
  """Decorator for asserting that no new Python objects persist after a test.

  Runs the test multiple times executing eagerly, first as a warmup and then to
  let objects accumulate. The warmup helps ignore caches which do not grow as
  the test is run repeatedly.

  Useful for checking that there are no missing Py_DECREFs in the C exercised by
  a bit of Python.

  Automatic garbage collection is disabled while the test runs and is always
  re-enabled afterwards, including when the test raises or a leak check fails.

  Can be applied bare (`@assert_no_new_pyobjects_executing_eagerly`) or with
  arguments (`@assert_no_new_pyobjects_executing_eagerly(warmup_iters=3)`).

  Args:
    func: The function to test. When None, a decorator is returned instead.
    warmup_iters: The number of warmup iterations, excluded from measuring.

  Returns:
    The wrapped function performing the test, or a decorator when `func` is
    None.
  """

  def wrap_f(f):
    def decorator(self, *args, **kwargs):
      """Warms up, gets object counts, runs the test, checks for new objects."""
      with context.eager_mode():
        gc.disable()
        try:
          # Run the test `warmup_iters` times as warmup, in an attempt to fill
          # up caches, which should not grow as the test is run repeatedly
          # below.
          #
          # TODO(b/117156879): Running warmup twice is black magic; we have
          # seen tests that fail with 1 warmup run, and pass with 2, on various
          # versions of python2.7.x.
          for _ in range(warmup_iters):
            f(self, *args, **kwargs)

          # Some objects are newly created by _get_object_count_by_type().  So
          # create and save as a dummy variable to include it as a baseline.
          obj_count_by_type = _get_object_count_by_type()
          gc.collect()
          obj_count_by_type = _get_object_count_by_type()

          if ops.has_default_graph():
            collection_sizes_before = {
                collection: len(ops.get_collection(collection))
                for collection in ops.get_default_graph().collections
            }
          for _ in range(3):
            f(self, *args, **kwargs)
          # Note that gc.get_objects misses anything that isn't subject to
          # garbage collection (C types). Collections are a common source of
          # leaks, so we test for collection sizes explicitly.
          if ops.has_default_graph():
            for collection_key in ops.get_default_graph().collections:
              collection = ops.get_collection(collection_key)
              size_before = collection_sizes_before.get(collection_key, 0)
              if len(collection) > size_before:
                raise AssertionError(
                    ("Collection %s increased in size from "
                     "%d to %d (current items %s).") %
                    (collection_key, size_before, len(collection), collection))
              # Make sure our collection checks don't show up as leaked memory
              # by removing references to temporary variables.
              del collection
              del collection_key
              del size_before
            del collection_sizes_before
          gc.collect()

          # There should be no new Python objects hanging around. Counter
          # subtraction keeps only positive differences, so types whose count
          # shrank are dropped.
          obj_count_by_type = _get_object_count_by_type() - obj_count_by_type
          # In some cases (specifically on MacOS), new_count is somehow
          # smaller than previous_count.
          # Using plain assert because not all classes using this decorator
          # have assertLessEqual
          assert not obj_count_by_type, (
              "The following objects were newly created: %s" %
              str(obj_count_by_type))
        finally:
          # Always turn collection back on so that a failing test or leak
          # check does not leave the garbage collector disabled for the tests
          # that run afterwards.
          gc.enable()
    return decorator

  if func is None:
    return wrap_f
  else:
    return wrap_f(func)
653
654
def assert_no_new_tensors(f):
  """Decorator for asserting that no new Tensors persist after a test.

  Mainly useful for checking that code using the Python C API has correctly
  manipulated reference counts.

  The test is run inside a fresh graph. Afterwards the known caches are
  cleared, the garbage collector is run, and any Tensor or Tensor-like object
  (Tensor, Variable, Dimension, TensorShape) that did not exist before the test
  is reported. This covers Tensors to which something still holds a reference
  (e.g. from missing Py_DECREFs) and uncollectable cycles (i.e. Python
  reference cycles where one of the objects has __del__ defined).

  Args:
    f: The test case to run.

  Returns:
    The decorated test case.
  """

  def decorator(self, **kwargs):
    """Finds existing Tensors, runs the test, checks for new Tensors."""
    tracked_types = (ops.Tensor, variables.Variable, tensor_shape.Dimension,
                     tensor_shape.TensorShape)

    def _is_tensorflow_object(candidate):
      try:
        return isinstance(candidate, tracked_types)
      except ReferenceError:
        # If the object no longer exists, we don't care about it.
        return False

    # Remember by id (not by reference) which objects already existed, so this
    # bookkeeping does not itself keep anything alive.
    preexisting_ids = {
        id(tracked) for tracked in gc.get_objects()
        if _is_tensorflow_object(tracked)
    }
    was_eager = context.executing_eagerly()
    # Run the test in a new graph so that collections get cleared when it's
    # done, but inherit the graph key so optimizers behave.
    inherited_graph_key = ops.get_default_graph()._graph_key
    with ops.Graph().as_default():
      ops.get_default_graph()._graph_key = inherited_graph_key
      if was_eager:
        with context.eager_mode():
          result = f(self, **kwargs)
      else:
        result = f(self, **kwargs)
    # Make an effort to clear caches, which would otherwise look like leaked
    # Tensors.
    context.context()._clear_caches()  # pylint: disable=protected-access
    gc.collect()
    leaked = [
        tracked for tracked in gc.get_objects()
        if _is_tensorflow_object(tracked) and id(tracked) not in preexisting_ids
    ]
    if leaked:
      raise AssertionError(
          "%d Tensors not deallocated after test: %s" %
          (len(leaked), str(leaked)))
    return result

  return decorator
715
716
def _find_reference_cycle(objects, idx):
  """Looks for a reference cycle reachable from the referrers of `objects[idx]`.

  Builds a graph of object IDs by recursively following `gc.get_referrers`,
  starting at `objects[idx]`, then searches that graph for a cycle. The first
  cycle found is logged with `logging.error`.

  Args:
    objects: Indexable collection of objects; it is excluded from the graph.
    idx: Index into `objects` of the object to start from.

  Returns:
    True if a cycle was found (and logged), False otherwise.
  """

  def get_ignore_reason(obj, blacklist):
    """Tests whether an object should be omitted from the dependency graph."""
    # The blacklist grows by one entry per recursion level of build_ref_graph,
    # so its length also acts as a recursion depth limit.
    if len(blacklist) > 100:
      return "<depth limit>"
    if tf_inspect.isframe(obj):
      # Element 0 of the frame info is presumably the file name -- skips
      # frames that belong to this test utility file.
      if "test_util.py" in tf_inspect.getframeinfo(obj)[0]:
        return "<test code>"
    # Identity (not equality) comparison: only the exact bookkeeping objects
    # are ignored.
    for b in blacklist:
      if b is obj:
        return "<test code>"
    if obj is blacklist:
      return "<test code>"
    return None

  # Note: this function is meant to help with diagnostics. Its output is purely
  # a human-readable representation, so you may freely modify it to suit your
  # needs.
  def describe(obj, blacklist, leaves_only=False):
    """Returns a custom human-readable summary of obj.

    Args:
      obj: the value to describe.
      blacklist: same as blacklist in get_ignore_reason.
      leaves_only: boolean flag used when calling describe recursively. Useful
        for summarizing collections.
    """
    if get_ignore_reason(obj, blacklist):
      return "{}{}".format(get_ignore_reason(obj, blacklist), type(obj))
    if tf_inspect.isframe(obj):
      return "frame: {}".format(tf_inspect.getframeinfo(obj))
    elif tf_inspect.ismodule(obj):
      return "module: {}".format(obj.__name__)
    else:
      if leaves_only:
        return "{}, {}".format(type(obj), id(obj))
      elif isinstance(obj, list):
        return "list({}): {}".format(
            id(obj), [describe(e, blacklist, leaves_only=True) for e in obj])
      elif isinstance(obj, tuple):
        return "tuple({}): {}".format(
            id(obj), [describe(e, blacklist, leaves_only=True) for e in obj])
      elif isinstance(obj, dict):
        return "dict({}): {} keys".format(id(obj), len(obj.keys()))
      elif tf_inspect.isfunction(obj):
        return "function({}) {}; globals ID: {}".format(
            id(obj), obj.__name__, id(obj.__globals__))
      else:
        return "{}, {}".format(type(obj), id(obj))

  def build_ref_graph(obj, graph, reprs, blacklist):
    """Builds a reference graph as <referrer> -> <list of referents>.

    Args:
      obj: The object to start from. The graph will be built by recursively
        adding its referrers.
      graph: Dict holding the graph to be built. To avoid creating extra
        references, the graph holds object IDs rather than actual objects.
      reprs: Auxiliary structure that maps object IDs to their human-readable
        description.
      blacklist: Tuple of objects to ignore.
    """
    referrers = gc.get_referrers(obj)
    # The referrers list itself references every referrer; blacklist it so it
    # is not mistaken for part of the object graph.
    blacklist = blacklist + (referrers,)

    obj_id = id(obj)
    for r in referrers:
      if get_ignore_reason(r, blacklist) is None:
        r_id = id(r)
        if r_id not in graph:
          graph[r_id] = []
        # Recurse only for edges not seen before, which also bounds the
        # traversal when the references form a cycle.
        if obj_id not in graph[r_id]:
          graph[r_id].append(obj_id)
          build_ref_graph(r, graph, reprs, blacklist)
          reprs[r_id] = describe(r, blacklist)

  def find_cycle(el, graph, reprs, path):
    """Finds and prints a single cycle in the dependency graph."""
    # Note: this early exit returns None (falsy) rather than False.
    if el not in graph:
      return
    for r in graph[el]:
      if r in path:
        logging.error("Reference cycle sample:")
        for p in path + (r,):
          logging.error(reprs.get(p, "unknown object " + str(p)))
        return True
      else:
        if find_cycle(r, graph, reprs, path + (r,)):
          return True
    return False

  obj = objects[idx]
  graph = {}  # referrer ID -> list of referent object IDs
  reprs = {}  # object ID -> description
  # Seed the blacklist with this function's own bookkeeping objects and helper
  # closures so they never appear in the graph.
  build_ref_graph(obj, graph, reprs, (objects, graph, reprs, get_ignore_reason,
                                      describe, build_ref_graph, find_cycle))
  for k in graph:
    if find_cycle(k, graph, reprs, ()):
      return True
  return False
818
819
def assert_no_garbage_created(f):
  """Test method decorator to assert that no garbage has been created.

  Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters
  cannot be un-set (i.e. will disable garbage collection for any other unit
  tests in the same file/shard).

  The previous gc debug flags are restored and automatic collection is
  re-enabled once the decorated test finishes, including when the test raises
  or the garbage assertion fails.

  Args:
    f: The function to decorate.

  Returns:
    The decorated function.
  """

  def decorator(self, **kwargs):
    """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
    # Force-load `distribution_strategy_context` to prevent GC at
    # test time when using eager. Remove once b/117329403 is resolved.
    tape.distribution_strategy_context.get_strategy()

    gc.disable()
    previous_debug_flags = gc.get_debug()
    try:
      # With DEBUG_SAVEALL every unreachable object found by the collector is
      # kept in gc.garbage, so growth of that list reveals reference cycles.
      gc.set_debug(gc.DEBUG_SAVEALL)
      gc.collect()
      previous_garbage = len(gc.garbage)
      result = f(self, **kwargs)
      gc.collect()
      new_garbage = len(gc.garbage)
      if new_garbage > previous_garbage:
        logging.error(
            "The decorated test created work for Python's garbage collector, "
            "likely due to a reference cycle. New objects in cycle(s):")

        def _safe_object_str(obj):
          return "<%s %d>" % (obj.__class__.__name__, id(obj))

        for i, obj in enumerate(gc.garbage[previous_garbage:]):
          try:
            logging.error("Object %d of %d", i,
                          len(gc.garbage) - previous_garbage)
            logging.error("  Object type: %s", _safe_object_str(obj))
            logging.error(
                "  Referrer types: %s", ", ".join(
                    [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
            logging.error(
                "  Referent types: %s", ", ".join(
                    [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
            logging.error("  Object attribute names: %s", dir(obj))
            logging.error("  Object __str__:")
            logging.error(obj)
            logging.error("  Object __repr__:")
            logging.error(repr(obj))
          except Exception:  # pylint: disable=broad-except
            logging.error("(Exception while printing object)")

        # When garbage is created, this call can help identify reference
        # cycles, which are typically the cause of such garbage. Stop at the
        # first cycle that gets reported.
        for i in range(previous_garbage, new_garbage):
          if _find_reference_cycle(gc.garbage, i):
            break

      # This will fail if any garbage has been created, typically because of a
      # reference cycle.
      self.assertEqual(previous_garbage, new_garbage)
    finally:
      # TODO(allenl): Figure out why this debug flag reset doesn't work. It
      # would be nice to be able to decorate arbitrary tests in a large test
      # suite and not hold on to every object in other tests.
      gc.set_debug(previous_debug_flags)
      gc.enable()
    return result

  return decorator
893
894
def _combine_named_parameters(**kwargs):
  """Expands keyword options into every possible combination of values.

  Two sets of returned combinations can be concatenated using +.  Their product
  can be computed using `times()`.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
      `option=the_only_possibility`. Only `list` values are expanded; any
      other value is treated as a single possibility.

  Returns:
    a list of `OrderedDict`s, one per combination. Keys are the keyword
    argument names in sorted order; each key maps to one of the values given
    for that keyword argument.
  """
  per_option_pairs = []
  # Keyword names are unique, so ordering by name alone fixes the option order.
  for option in sorted(kwargs):
    choices = kwargs[option]
    candidates = choices if isinstance(choices, list) else [choices]
    per_option_pairs.append([(option, candidate) for candidate in candidates])

  return [
      OrderedDict(assignment)
      for assignment in itertools.product(*per_option_pairs)
  ]
918
919
def generate_combinations_with_testcase_name(**kwargs):
  """Builds option combinations, each tagged with a generated test case name.

  The keyword arguments are expanded with `_combine_named_parameters()` and a
  'testcase_name' entry is added to every resulting dictionary. The
  'testcase_name' key is required for named parameterized tests.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
      `option=the_only_possibility`.

  Returns:
    a list of `OrderedDict`s, one per combination. Keys are the keyword
    argument names plus 'testcase_name'. The name is "_test" followed by
    "_<option>_<value>" for every option, with non-alphanumeric characters
    removed from both parts.
  """

  def _alnum_only(text):
    # Keep only the alphanumeric characters of `text`.
    return "".join(ch for ch in text if ch.isalnum())

  named = []
  for combo in _combine_named_parameters(**kwargs):
    assert isinstance(combo, OrderedDict)
    suffix = "".join(
        "_{}_{}".format(_alnum_only(option), _alnum_only(str(choice)))
        for option, choice in combo.items())
    entry = OrderedDict(combo)
    entry["testcase_name"] = "_test{}".format(suffix)
    named.append(entry)

  return named
951
952
def run_all_in_graph_and_eager_modes(cls):
  """Execute all test methods in the given class with and without eager.

  Every callable attribute whose name starts with the unittest test prefix is
  wrapped with `run_in_graph_and_eager_modes`, except `test_session` and names
  starting with "testSkipEager" or "test_skip_eager".

  Args:
    cls: The test class to modify in place.

  Returns:
    The same class object, with its eligible test methods replaced.
  """
  wrap = run_in_graph_and_eager_modes
  eager_opt_out_prefixes = ("testSkipEager", "test_skip_eager")
  for attr_name in dir(cls):
    if not attr_name.startswith(unittest.TestLoader.testMethodPrefix):
      continue
    if (attr_name.startswith(eager_opt_out_prefixes) or
        attr_name == "test_session"):
      continue
    member = getattr(cls, attr_name, None)
    if callable(member):
      setattr(cls, attr_name, wrap(member))
  return cls
966
967
def build_as_function_and_v1_graph(func=None):
  """Run a test case in v1 graph mode and inside tf.function in eager mode.

  WARNING: This decorator can only be used in test cases that statically check
  the generated graph. Attempting to evaluate graph or function results via
  session.run() or self.evaluate() will fail.

  WARNING: This decorator can only be used for test cases that inherit from
  absl.testing.parameterized.TestCase.

  Args:
    func: Test case function to be decorated. If None, a decorator is returned
      instead of a decorated function.

  Returns:
    Decorated test case function, or a decorator when `func` is None.

  Raises:
    ValueError: If the decorator is applied to a class, or (from the decorated
      test) if it is invoked with an unknown run mode.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError(
          "`build_as_function_and_v1_graph` only supports test methods.")

    # Each decorated test is expanded into two parameterized cases; the second
    # element of each tuple is passed to `decorated` as `run_mode`.
    @parameterized.named_parameters(("_v1_graph", "v1_graph"),
                                    ("_function", "function"))
    @functools.wraps(f)
    def decorated(self, run_mode, *args, **kwargs):
      if run_mode == "v1_graph":
        with ops.Graph().as_default():
          f(self, *args, **kwargs)
      elif run_mode == "function":

        @def_function.function
        def function_in_eager():
          f(self, *args, **kwargs)

        # Create a new graph for the eagerly executed version of this test for
        # better isolation.
        graph_for_eager_test = ops.Graph()
        with graph_for_eager_test.as_default(), context.eager_mode():
          function_in_eager()
        ops.dismantle_graph(graph_for_eager_test)
      else:
        # Raise rather than return: a returned exception object would be
        # silently discarded and the test would appear to pass.
        raise ValueError("Unknown run mode %s" % run_mode)

    return decorated

  if func is not None:
    return decorator(func)

  return decorator
1018
1019
def eager_lazy_remote_copy_on_and_off(f):
  """Execute the test method w/o lazy tensor copy for function remote inputs.

  The test is parameterized into two cases: "WithLazyRemoteCopy", which sets
  `lazy_remote_inputs_copy` on the eager context to True, and an unnamed case
  which sets it to False.

  Args:
    f: The test method to decorate.

  Returns:
    The parameterized test method.
  """

  @parameterized.named_parameters([("WithLazyRemoteCopy", True), ("", False)])
  @functools.wraps(f)
  def run_with_setting(self, lazily_remote_copy, *args, **kwargs):
    # Store a real bool on the context regardless of what was passed in.
    context.context().lazy_remote_inputs_copy = (
        True if lazily_remote_copy else False)
    f(self, *args, **kwargs)

  return run_with_setting
1033
1034
def run_in_graph_and_eager_modes(func=None,
                                 config=None,
                                 use_gpu=True,
                                 reset_test=True,
                                 assert_no_eager_garbage=False):
  """Execute the decorated test with and without enabling eager execution.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will cause the contents of the test
  method to be executed twice - once normally, and once with eager execution
  enabled. This allows unittests to confirm the equivalence between eager
  and graph execution (see `tf.compat.v1.enable_eager_execution`).

  For example, consider the following unittest:

  ```python
  class MyTests(tf.test.TestCase):

    @run_in_graph_and_eager_modes
    def test_foo(self):
      x = tf.constant([1, 2])
      y = tf.constant([3, 4])
      z = tf.add(x, y)
      self.assertAllEqual([4, 6], self.evaluate(z))

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test validates that `tf.add()` has the same behavior when computed with
  eager execution enabled as it does when constructing a TensorFlow graph and
  executing the `z` tensor in a session.

  A `unittest.SkipTest` raised during the graph-mode run is caught, and the
  eager-mode run is still executed afterwards.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.


  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.
    config: An optional config_pb2.ConfigProto to use to configure the session
      when executing graphs.
    use_gpu: If True, attempt to run as many operations as possible on GPU.
      If False, the eager run is placed on "/device:CPU:0".
    reset_test: If True, tearDown and SetUp the test case between the two
      executions of the test (once with and once without eager execution).
    assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage
      collector and asserts that no extra garbage has been created when running
      the test with eager execution enabled. This will fail if there are
      reference cycles (e.g. a = []; a.append(a)). Off by default because some
      tests may create garbage for legitimate reasons (e.g. they define a class
      which inherits from `object`), and because DEBUG_SAVEALL is sticky in some
      Python interpreters (meaning that tests which rely on objects being
      collected elsewhere in the unit test file will not work). Additionally,
      checks that nothing still has a reference to Tensors that the test
      allocated.

  Returns:
    Returns a decorator that will run the decorated test method twice:
    once by constructing and executing a graph in a session and once with
    eager execution enabled.

  Raises:
    ValueError: If the decorator is applied to a class instead of a method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError(
          "`run_in_graph_and_eager_modes` only supports test methods. "
          "Did you mean to use `run_all_in_graph_and_eager_modes`?")

    def decorated(self, *args, **kwargs):
      # First pass: graph mode, inside a test session.
      try:
        with context.graph_mode():
          with self.test_session(use_gpu=use_gpu, config=config):
            f(self, *args, **kwargs)
      except unittest.case.SkipTest:
        # A skip in graph mode does not prevent the eager pass below.
        pass

      # Second pass body. `args` is captured from the enclosing call, while
      # `kwargs` is this function's own parameter.
      def run_eagerly(self, **kwargs):
        if not use_gpu:
          with ops.device("/device:CPU:0"):
            f(self, *args, **kwargs)
        else:
          f(self, *args, **kwargs)

      if assert_no_eager_garbage:
        ops.reset_default_graph()
        # Rebind `run_eagerly` so the eager pass also runs the garbage and
        # leaked-tensor checks.
        run_eagerly = assert_no_new_tensors(
            assert_no_garbage_created(run_eagerly))

      if reset_test:
        # This decorator runs the wrapped test twice.
        # Reset the test environment between runs.
        self.tearDown()
        # NOTE(review): presumably clears the cached temp dir so the second
        # run gets a fresh one; confirm against the TestCase implementation.
        self._tempdir = None
      # Create a new graph for the eagerly executed version of this test for
      # better isolation.
      graph_for_eager_test = ops.Graph()
      with graph_for_eager_test.as_default(), context.eager_mode():
        if reset_test:
          # setUp runs inside the eager context, after the tearDown above.
          self.setUp()
        run_eagerly(self, **kwargs)
      ops.dismantle_graph(graph_for_eager_test)

    return decorated

  # Support both `@run_in_graph_and_eager_modes` and
  # `@run_in_graph_and_eager_modes(...)`.
  if func is not None:
    return decorator(func)

  return decorator
1145
1146
def py_func_if_in_function(f):
  """Wraps `f` so it runs through `py_func` while a function is being built.

  When the default graph is not building a function, the wrapper calls `f`
  directly and returns its result. Otherwise the positional arguments that are
  `Tensor`s or `Variable`s are passed as inputs to `script_ops.py_func`; the
  Python callback puts the values it receives back into their original
  positions before calling `f`.

  Args:
    f: The callable to wrap.

  Returns:
    The wrapped callable, produced by `tf_decorator.make_decorator`.
  """

  def decorated(*args, **kwds):
    if not ops.get_default_graph()._building_function:
      return f(*args, **kwds)

    # Remember which positional arguments are tensors and where they were.
    tensor_slots = [
        (position, value)
        for position, value in enumerate(args)
        if isinstance(value, (ops.Tensor, variables.Variable))
    ]
    positions = [position for position, _ in tensor_slots]
    tensor_inputs = [value for _, value in tensor_slots]

    def inner_f(*inner_tensor_args):
      call_args = list(args)
      for position, replacement in zip(positions, inner_tensor_args):
        call_args[position] = replacement
      return f(*call_args, **kwds)

    return script_ops.py_func(inner_f, tensor_inputs, [])

  return tf_decorator.make_decorator(f, decorated)
1169
1170
def also_run_as_tf_function(f):
  """Runs the decorated test twice--once as is, once inside a tf.function.

  Both runs happen with eager execution enabled: first the test is called
  directly, then it is wrapped in a `tf.function` (with autograph disabled)
  and called again. This exercises the two execution modes supported in
  tf 2.0. The test assertions are automatically done inside tf.py_funcs, and
  tf.function ensures that they run in the proper order and with the proper
  side effects.

  Currently variable creation is not supported in tests annotated with this
  decorator since it's tricky to ensure the variable doesn't get repeatedly
  created when retracing the tf.function.

  Args:
    f: the test method to be decorated

  Returns:
    The decorated test method, which will run both in eager and inside a
    tf.function.
  """

  def decorated(*args, **kwds):

    def bound_f():
      # The return value of the test is deliberately dropped.
      f(*args, **kwds)

    with context.eager_mode():
      # Plain eager run first.
      bound_f()
      # Then the same body traced as a TF function.
      # TODO(b/121143941): Remove the autograph override.
      traced = def_function.function(bound_f, autograph=False)
      traced()

  return decorated
1204
1205
def deprecated_graph_mode_only(func=None):
  """Execute the decorated test in graph mode.

  This function returns a decorator intended to be applied to tests that are not
  compatible with eager mode. When this decorator is applied, the test body will
  be run in an environment where API calls construct graphs instead of executing
  eagerly.

  The decorator may also be applied to a class, in which case the `setUp`
  defined directly on that class (if any) and every callable defined directly
  on it whose name starts with the unittest test prefix are decorated.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will run the decorated test method in graph mode.
  """

  def decorator(f):
    if not tf_inspect.isclass(f):

      def decorated(self, *args, **kwargs):
        if not context.executing_eagerly():
          return f(self, *args, **kwargs)
        with context.graph_mode():
          return f(self, *args, **kwargs)

      return decorated

    # Class case: only attributes defined on the class itself are considered.
    own_setup = f.__dict__.get("setUp")
    if own_setup is not None:
      setattr(f, "setUp", decorator(own_setup))

    test_prefix = unittest.TestLoader.testMethodPrefix
    # Iterate over a snapshot because the class is modified inside the loop.
    for attr_name, member in f.__dict__.copy().items():
      if callable(member) and attr_name.startswith(test_prefix):
        setattr(f, attr_name, decorator(member))

    return f

  return decorator if func is None else decorator(func)
1253
1254
# Alias: `run_deprecated_v1` is the very same decorator object as
# `deprecated_graph_mode_only`.
run_deprecated_v1 = deprecated_graph_mode_only
1256
1257
def run_all_in_deprecated_graph_mode_only(cls):
  """Execute all tests in a class in graph mode.

  Every callable attribute of `cls` whose name starts with the unittest test
  prefix, other than `test_session`, is wrapped with
  `deprecated_graph_mode_only`.

  Args:
    cls: The test class to modify in place.

  Returns:
    The same class object, with its test methods replaced.
  """
  wrap = deprecated_graph_mode_only
  for attr_name in dir(cls):
    is_test = attr_name.startswith(unittest.TestLoader.testMethodPrefix)
    if is_test and attr_name != "test_session":
      member = getattr(cls, attr_name, None)
      if callable(member):
        setattr(cls, attr_name, wrap(member))
  return cls
1269
1270
def run_v1_only(reason, func=None):
  """Execute the decorated test only if running in v1 mode.

  This function is intended to be applied to tests that exercise v1 only
  functionality. If the test is run in v2 mode it will simply be skipped.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    reason: string giving a reason for limiting the test to v1 only.
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: If `reason` is not a string.
  """
  if not isinstance(reason, str):
    raise ValueError("'reason' should be string, got {}".format(type(reason)))

  def decorator(f):
    if not tf_inspect.isclass(f):
      # Plain function: wrap it so it skips itself when v2 is enabled.
      def decorated(self, *args, **kwargs):
        if tf2.enabled():
          self.skipTest(reason)

        return f(self, *args, **kwargs)

      return decorated

    # To skip a whole test class only its setUp method is decorated. setUp may
    # not be defined on the class itself, so the method resolution order is
    # walked and the first setUp found (usually the one inherited from the
    # TestCase base class) is wrapped and installed on `f`.
    for klass in type.mro(f):
      found_setup = klass.__dict__.get("setUp")
      if found_setup is None:
        continue
      setattr(f, "setUp", decorator(found_setup))
      break

    return f

  return decorator if func is None else decorator(func)
1322
1323
def run_v2_only(func=None):
  """Execute the decorated test only if running in v2 mode.

  This function is intended to be applied to tests that exercise v2 only
  functionality. If the test is run in v1 mode it will simply be skipped.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: If the decorator is applied to a class.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_v2_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      v2_active = tf2.enabled()
      if not v2_active:
        self.skipTest("Test is only compatible with v2")

      return f(self, *args, **kwargs)

    return decorated

  return decorator if func is None else decorator(func)
1359
1360
def run_gpu_only(func=None):
  """Execute the decorated test only if a GPU is available.

  This function is intended to be applied to tests that require the presence
  of a GPU. If a GPU is absent, the test is skipped. Availability is checked
  with `is_gpu_available()` each time the test runs.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: If the decorator is applied to a class.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_gpu_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      gpu_present = is_gpu_available()
      if not gpu_present:
        self.skipTest("Test requires GPU")

      return f(self, *args, **kwargs)

    return decorated

  return decorator if func is None else decorator(func)
1392
1393
def run_cuda_only(func=None):
  """Execute the decorated test only if a CUDA GPU is available.

  This function is intended to be applied to tests that require the presence
  of a CUDA GPU. If a CUDA GPU is absent, the test is skipped. Availability is
  checked with `is_gpu_available(cuda_only=True)` each time the test runs.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: If the decorator is applied to a class.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_cuda_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      cuda_gpu_present = is_gpu_available(cuda_only=True)
      if not cuda_gpu_present:
        self.skipTest("Test requires CUDA GPU")

      return f(self, *args, **kwargs)

    return decorated

  return decorator if func is None else decorator(func)
1425
1426
def with_forward_compatibility_horizons(*horizons):
  """Executes the decorated test with the specified forward-compat horizons.

  The decorated test body is run once per entry in `horizons`, in order.

  Args:
    *horizons: A list of (year, month, day) tuples.  If the list includes
      `None`, then the test will also be run with no forward-compatibility
      horizon set.

  Returns:
    A decorator that will execute the test with the specified horizons.

  Raises:
    ValueError: If no horizon is given, or if any horizon is neither `None`
      nor a length-3 sequence of ints.
  """
  if not horizons:
    raise ValueError("Expected at least one horizon.")

  def _is_valid_horizon(horizon):
    """Returns True for `None` or a length-3 sequence of ints."""
    if horizon is None:
      return True
    try:
      return len(horizon) == 3 and all(isinstance(x, int) for x in horizon)
    except TypeError:
      # Not sized / not iterable (e.g. a bare int): report it as a bad value
      # instead of leaking a TypeError.
      return False

  for horizon in horizons:
    if not _is_valid_horizon(horizon):
      # Wrap in a 1-tuple: a bare tuple on the right-hand side of `%` would be
      # unpacked as the format arguments and raise TypeError.
      raise ValueError("Bad horizon value: %r" % (horizon,))

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`with_forward_compatibility_horizons` only "
                       "supports test methods.")

    def decorated(self, *args, **kwargs):
      for horizon in horizons:
        if horizon is None:
          f(self, *args, **kwargs)
        else:
          (year, month, day) = horizon
          with forward_compatibility_horizon(year, month, day):
            f(self, *args, **kwargs)

    return decorated

  return decorator
1460
1461
@deprecation.deprecated(None,
                        "Use `tf.config.list_physical_devices('GPU')` instead.")
@tf_export("test.is_gpu_available")
def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
  """Returns whether TensorFlow can access a GPU.

  Warning: if a non-GPU version of the package is installed, the function would
  also return False. Use `tf.test.is_built_with_cuda` to validate if TensorFlow
  was built with CUDA support.

  Note that the keyword arg name "cuda_only" is misleading (since the routine
  will return true when a GPU device is available irrespective of whether TF
  was built with CUDA support or ROCm support). However no changes here because

  ++ Changing the name "cuda_only" to something more generic would break
     backward compatibility

  ++ Adding an equivalent "rocm_only" would require the implementation check
     the build type. This in turn would require doing the same for CUDA and thus
     potentially break backward compatibility

  ++ Adding a new "cuda_or_rocm_only" would not break backward compatibility,
     but would require most (if not all) callers to update the call to use
     "cuda_or_rocm_only" instead of "cuda_only"

  Args:
    cuda_only: limit the search to CUDA GPUs. When False, a "SYCL" device also
      counts as an available GPU.
    min_cuda_compute_capability: a (major,minor) pair that indicates the minimum
      CUDA compute capability required, or None if no requirement.

  Returns:
    True if a GPU device of the requested kind is available.
  """
  try:
    for local_device in device_lib.list_local_devices():
      kind = local_device.device_type
      if kind == "GPU":
        gpu_info = gpu_util.compute_capability_from_device_desc(local_device)
        # Treat an unknown compute capability as the lowest possible one.
        cc = gpu_info.compute_capability or (0, 0)
        meets_minimum = (not min_cuda_compute_capability or
                         cc >= min_cuda_compute_capability)
        if meets_minimum:
          return True
      elif kind == "SYCL" and not cuda_only:
        return True
    return False
  except errors_impl.NotFoundError as e:
    message = str(e)
    # Only an error mentioning both "CUDA" and "not find" is treated as
    # "no GPU"; any other NotFoundError is propagated.
    if "CUDA" in message and "not find" in message:
      logging.error(message)
      return False
    raise e
1511
1512
@contextlib.contextmanager
def device(use_gpu):
  """Uses gpu when requested and available.

  Args:
    use_gpu: If true and `is_gpu_available()` reports a GPU, ops are placed on
      "/device:GPU:0"; otherwise they are placed on "/device:CPU:0".

  Yields:
    Nothing; the device scope is active for the duration of the `with` block.
  """
  want_gpu = use_gpu and is_gpu_available()
  target = "/device:GPU:0" if want_gpu else "/device:CPU:0"
  with ops.device(target):
    yield
1522
1523
@contextlib.contextmanager
def use_gpu():
  """Uses gpu when available.

  Equivalent to `device(use_gpu=True)`.

  Yields:
    Nothing; the device scope is active for the duration of the `with` block.
  """
  gpu_if_available = device(use_gpu=True)
  with gpu_if_available:
    yield
1529
1530
@contextlib.contextmanager
def force_gpu():
  """Force the gpu to be used.

  Opens an `ops.device("/device:GPU:0")` scope without checking whether a GPU
  is available.

  Yields:
    Nothing; the device scope is active for the duration of the `with` block.
  """
  gpu_scope = ops.device("/device:GPU:0")
  with gpu_scope:
    yield
1536
1537
@contextlib.contextmanager
def force_cpu():
  """Force the cpu to be used.

  Opens an `ops.device("/device:CPU:0")` scope.

  Yields:
    Nothing; the device scope is active for the duration of the `with` block.
  """
  cpu_scope = ops.device("/device:CPU:0")
  with cpu_scope:
    yield
1543
1544
class CapturedWrites(object):
  """A utility class to load the captured writes made to a stream.

  Attributes are documented inline; the file is only opened when `contents()`
  is called.
  """

  def __init__(self, capture_location):
    """Remembers where the captured output is stored.

    Args:
      capture_location: Path of the file holding the captured writes.
    """
    # Path of the capture file; read lazily by `contents()`.
    self.capture_location = capture_location

  def contents(self):
    """Get the captured writes as a single string.

    Returns:
      The full text of the file at `capture_location`.
    """
    with open(self.capture_location) as captured:
      return captured.read()
1556
1557
class FakeEagerSession(object):
  """Fake session so tests that conditionally use placeholders can use eager.

  There are a number of tests that conditionally use placeholders for shape
  inference. The pattern is demonstrated here:

  ```python
  with self.cached_session() as sess:
    if static_shape:
      y = math_ops.matmul(x, ...)
      feed_dict = {}
    else:
      x_ph = array_ops.placeholder(...)
      y = math_ops.matmul(x_ph, ...)
      feed_dict = {x_ph: x}
    val = sess.run(y, feed_dict=feed_dict)
  ```

  Since the feed_dict is empty when not using placeholders we should be able to
  call self.evaluate(), however this requires rewriting the test case.
  This class should be considered a stop-gap solution to get tests running with
  eager with minimal changes to the actual test.
  """

  def __init__(self, test_case):
    """Stores the test case whose `evaluate()` is used by `run()`.

    Args:
      test_case: Object providing an `evaluate(fetches)` method.
    """
    self._test_case = test_case

  def run(self, fetches, *args, **kwargs):
    """Evaluate `fetches`.

    An empty `feed_dict` keyword argument is accepted and ignored. Fails if a
    non-empty `feed_dict` or any other additional argument is specified.

    Args:
      fetches: A Tensor or a nested list/tuple of Tensors.
      *args: Positional arguments
      **kwargs: Keyword arguments

    Raises:
      RuntimeError: If a non-empty feed_dict, or any other args or kwargs, are
        specified.

    Returns:
      Tensors as numpy values.
    """
    eager_hint = ("when eager execution is enabled "
                  "(in this case, sess.run(t) is shorthand for t.numpy()")

    # `feed_dict` is removed from kwargs so an empty one does not count as an
    # unsupported optional argument below.
    if kwargs.pop("feed_dict", None):
      raise RuntimeError("feed_dict is not supported " + eager_hint)

    if args or kwargs:
      raise RuntimeError("Optional args are not supported " + eager_hint)

    return self._test_case.evaluate(fetches)
1613
1614
class ErrorLoggingSession(session.Session):
  """Wrapper around a Session that logs errors in run()."""

  def run(self, *args, **kwargs):
    """Runs the session, logging any exception before re-raising it.

    Args:
      *args: Positional arguments forwarded to `Session.run`.
      **kwargs: Keyword arguments forwarded to `Session.run`.

    Returns:
      Whatever `Session.run` returns.
    """
    try:
      return super(ErrorLoggingSession, self).run(*args, **kwargs)
    except errors.OutOfRangeError:
      # Note: OutOfRangeError is not logged. It is used as the completion
      # signal in tf.data tests, and logging it makes their output hard to
      # read.
      raise
    except Exception as e:  # pylint: disable=broad-except
      logging.error(str(e))
      raise
1628
1629
def disable_cudnn_autotune(func):
  """Disable autotuning during the call to this function.

  Some tests want to base assertions on a graph being isomorphic with a copy.
  To ensure this, this decorator disables autotuning.

  While the decorated function runs, the environment variable
  `TF_CUDNN_USE_AUTOTUNE` is set to "false" and `--xla_gpu_autotune_level=0`
  is appended to `XLA_FLAGS`. Both variables are restored to their previous
  state afterwards, even if the decorated function raises.

  Args:
    func: Function to run with CuDNN autotuning turned off. If None, a
      decorator is returned instead of a decorated function.

  Returns:
    Decorated function.
  """

  def _restore_env(name, previous_value):
    """Puts `name` back to `previous_value`, or unsets it if it was unset."""
    if previous_value is None:
      # pop() tolerates the wrapped function having removed the variable.
      os.environ.pop(name, None)
    else:
      os.environ[name] = previous_value

  def decorator(f):

    def decorated(self, *args, **kwargs):
      original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE")
      original_xla_flags = os.environ.get("XLA_FLAGS")

      os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false"
      new_xla_flags = "--xla_gpu_autotune_level=0"
      if original_xla_flags:
        # Keep whatever flags were already set and append ours.
        new_xla_flags = original_xla_flags + " " + new_xla_flags
      os.environ["XLA_FLAGS"] = new_xla_flags

      try:
        return f(self, *args, **kwargs)
      finally:
        # Restore in `finally` so a failing test does not leak the overrides
        # into the tests that run after it.
        _restore_env("TF_CUDNN_USE_AUTOTUNE", original_tf_cudnn_use_autotune)
        _restore_env("XLA_FLAGS", original_xla_flags)

    return decorated

  if func is not None:
    return decorator(func)

  return decorator
1673
1674
1675# The description is just for documentation purposes.
def enable_tf_xla_constant_folding(description):
  """Returns a decorator that enables XLA constant folding for a test.

  Args:
    description: String explaining why the decorator is used. It is only
      validated, not otherwise used.

  Returns:
    A decorator (`enable_tf_xla_constant_folding_impl`) for test methods.

  Raises:
    ValueError: If `description` is not a string.
  """

  if not isinstance(description, str):
    raise ValueError("'description' should be string, got {}".format(
        type(description)))

  def enable_tf_xla_constant_folding_impl(func):
    """Enable constant folding during the call to this function.

    Some tests fail without constant folding.

    Args:
      func: Function to run with constant folding turned on.

    Returns:
      Decorated function.
    """

    def decorator(f):

      def decorated(self, *args, **kwargs):
        original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled()
        pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False)
        try:
          return f(self, *args, **kwargs)
        finally:
          # Restore the previous setting even when the test fails, so the
          # override does not leak into later tests.
          pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var)

      return decorated

    if func is not None:
      return decorator(func)

    return decorator

  return enable_tf_xla_constant_folding_impl
1711
1712
1713# The description is just for documentation purposes.
def disable_xla(description):
  """Returns a decorator that runs a test only when XLA is not enabled.

  Args:
    description: Explanation of why the test is disabled under XLA. It is not
      used by the code.

  Returns:
    A decorator (`disable_xla_impl`) for test methods.
  """

  def disable_xla_impl(func):
    """Execute the test method only if xla is not enabled."""

    def decorator(test_method):

      def decorated(self, *args, **kwargs):
        if not is_xla_enabled():
          return test_method(self, *args, **kwargs)
        # XLA is enabled: return without running the test body.
        return None

      return decorated

    return decorator if func is None else decorator(func)

  return disable_xla_impl
1735
1736
def for_all_test_methods(decorator, *args, **kwargs):
  """Generate class-level decorator from given method-level decorator.

  It is expected for the given decorator to take some arguments and return
  a method that is then called on the test method to produce a decorated
  method.

  Args:
    decorator: The decorator to apply.
    *args: Positional arguments
    **kwargs: Keyword arguments

  Returns:
    Function that will decorate a given class's test methods with the
    decorator. Test methods are the callable attributes whose name starts with
    "test", excluding `test_session`.
  """

  def all_test_methods_impl(cls):
    """Apply decorator to all test methods in class."""
    for attr_name in dir(cls):
      member = getattr(cls, attr_name)
      if not callable(member):
        continue
      if attr_name == "test_session" or not attr_name.startswith("test"):
        continue
      # A fresh decorator instance is built for every method.
      setattr(cls, attr_name, decorator(*args, **kwargs)(member))
    return cls

  return all_test_methods_impl
1762
1763
1764# The description is just for documentation purposes.
def no_xla_auto_jit(description):  # pylint: disable=unused-argument
  """Builds a decorator for tests not meant to run with XLA auto jit.

  Args:
    description: Reason the test must not run under XLA. It is not read by the
      code; it only documents the call site.

  Returns:
    A function that accepts the test method (or None). Given a method it
    returns the guarded method; given None it returns a decorator that can be
    applied to a method later.
  """

  def no_xla_auto_jit_impl(func):
    """This test is not intended to be run with XLA auto jit enabled."""

    def decorator(func):

      def decorated(self, *args, **kwargs):
        # Skip the test body (returning None) if using XLA is forced.
        if not is_xla_enabled():
          return func(self, *args, **kwargs)
        return None

      return decorated

    return decorator if func is None else decorator(func)

  return no_xla_auto_jit_impl
1787
1788
1789# The description is just for documentation purposes.
def xla_allow_fallback(description):  # pylint: disable=unused-argument
  """Builds a decorator that enables XLA lazy compilation around a test.

  When XLA is enabled, the decorated test runs with lazy compilation switched
  on, and the previous setting is put back afterwards, including when the test
  raises. When XLA is not enabled the test runs unmodified.

  Args:
    description: Reason fallback is allowed for this test. It is not read by
      the code; it only documents the call site.

  Returns:
    A function that accepts the test method (or None). Given a method it
    returns the wrapped method; given None it returns a decorator that can be
    applied to a method later.
  """

  def xla_allow_fallback_impl(func):
    """Allow fallback to TF even though testing xla."""

    def decorator(func):

      def decorated(self, *args, **kwargs):
        if not is_xla_enabled():
          return func(self, *args, **kwargs)
        # Update the global XLABuildOpsPassFlags to enable lazy compilation,
        # which allows the compiler to fall back to TF classic. Remember the
        # old value so that we can reset it.
        old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True)
        try:
          return func(self, *args, **kwargs)
        finally:
          # Restore even if the test fails, so the changed global flag does
          # not leak into tests that run later in the same process.
          pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value)

      return decorated

    if func is not None:
      return decorator(func)

    return decorator

  return xla_allow_fallback_impl
1817
1818
class EagerSessionWarner(object):
  """Stand-in object whose attribute lookups fail with a porting hint.

  `__getattr__` unconditionally raises, so any attribute that is not found
  through normal lookup produces an AttributeError explaining that sessions
  are unavailable while eager execution is enabled.
  """

  def __getattr__(self, attr):
    message = (
        "Trying to access properties or call methods on the result of "
        "self.session(), self.cached_session(), etc while eager execution "
        "is enabled. If you're porting this test case to TF 2.0, either "
        "adapt the test to work with eager execution or insert a call to "
        "tf.disable_eager_execution() in the main() function of this test "
        "file.")
    raise AttributeError(message)
1829
1830
1831@tf_export("test.TestCase")
1832class TensorFlowTestCase(googletest.TestCase):
1833  """Base class for tests that need to test TensorFlow."""
1834
1835  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
1836    super(TensorFlowTestCase, self).__init__(methodName)
1837    if is_xla_enabled():
1838      pywrap_tf_session.TF_SetXlaAutoJitMode("2")
1839      pywrap_tf_session.TF_SetXlaMinClusterSize(1)
1840      pywrap_tf_session.TF_SetXlaEnableLazyCompilation(False)
1841      pywrap_tf_session.TF_SetTfXlaCpuGlobalJit(True)
1842      # Constant folding secretly runs code on TF:Classic CPU, so we also
1843      # disable it here.
1844      pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True)
1845
1846    self._threads = []
1847    self._tempdir = None
1848    self._cached_session = None
1849
1850  def setUp(self):
1851    self._ClearCachedSession()
1852    random.seed(random_seed.DEFAULT_GRAPH_SEED)
1853    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
1854    # Note: The following line is necessary because some test methods may error
1855    # out from within nested graph contexts (e.g., via assertRaises and
1856    # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty
1857    # under certain versions of Python. That would cause
1858    # ops.reset_default_graph() to throw an exception if the stack were not
1859    # cleared first.
1860    ops._default_graph_stack.reset()  # pylint: disable=protected-access
1861    ops.reset_default_graph()
1862    random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
1863    # Reset summary writer in case another test used set_as_default() with their
1864    # summary writer.
1865    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
1866    summary_state.writer = None
1867
1868    # Avoiding calling setUp() for the poorly named test_session method.
1869    if self.id().endswith(".test_session"):
1870      self.skipTest("Not a test.")
1871
1872  def tearDown(self):
1873    for thread in self._threads:
1874      thread.check_termination()
1875
1876    self._ClearCachedSession()
1877
1878  def _ClearCachedSession(self):
1879    if self._cached_session is not None:
1880      self._cached_session.close()
1881      self._cached_session = None
1882
1883  def get_temp_dir(self):
1884    """Returns a unique temporary directory for the test to use.
1885
1886    If you call this method multiple times during in a test, it will return the
1887    same folder. However, across different runs the directories will be
1888    different. This will ensure that across different runs tests will not be
1889    able to pollute each others environment.
1890    If you need multiple unique directories within a single test, you should
1891    use tempfile.mkdtemp as follows:
1892      tempfile.mkdtemp(dir=self.get_temp_dir()):
1893
1894    Returns:
1895      string, the path to the unique temporary directory created for this test.
1896    """
1897    if not self._tempdir:
1898      self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
1899    return self._tempdir
1900
1901  @contextlib.contextmanager
1902  def captureWritesToStream(self, stream):
1903    """A context manager that captures the writes to a given stream.
1904
1905    This context manager captures all writes to a given stream inside of a
1906    `CapturedWrites` object. When this context manager is created, it yields
1907    the `CapturedWrites` object. The captured contents can be accessed  by
1908    calling `.contents()` on the `CapturedWrites`.
1909
1910    For this function to work, the stream must have a file descriptor that
1911    can be modified using `os.dup` and `os.dup2`, and the stream must support
1912    a `.flush()` method. The default python sys.stdout and sys.stderr are
1913    examples of this. Note that this does not work in Colab or Jupyter
1914    notebooks, because those use alternate stdout streams.
1915
1916    Example:
1917    ```python
1918    class MyOperatorTest(test_util.TensorFlowTestCase):
1919      def testMyOperator(self):
1920        input = [1.0, 2.0, 3.0, 4.0, 5.0]
1921        with self.captureWritesToStream(sys.stdout) as captured:
1922          result = MyOperator(input).eval()
1923        self.assertStartsWith(captured.contents(), "This was printed.")
1924    ```
1925
1926    Args:
1927      stream: The stream whose writes should be captured. This stream must have
1928        a file descriptor, support writing via using that file descriptor, and
1929        must have a `.flush()` method.
1930
1931    Yields:
1932      A `CapturedWrites` object that contains all writes to the specified stream
1933      made during this context.
1934    """
1935    stream.flush()
1936    fd = stream.fileno()
1937    tmp_file_path = tempfile.mktemp(dir=self.get_temp_dir())
1938    tmp_file = open(tmp_file_path, "w")
1939    orig_fd = os.dup(fd)
1940    os.dup2(tmp_file.fileno(), fd)
1941    try:
1942      yield CapturedWrites(tmp_file_path)
1943    finally:
1944      tmp_file.close()
1945      os.dup2(orig_fd, fd)
1946
1947  def _AssertProtoEquals(self, a, b, msg=None):
1948    """Asserts that a and b are the same proto.
1949
1950    Uses ProtoEq() first, as it returns correct results
1951    for floating point attributes, and then use assertProtoEqual()
1952    in case of failure as it provides good error messages.
1953
1954    Args:
1955      a: a proto.
1956      b: another proto.
1957      msg: Optional message to report on failure.
1958    """
1959    if not compare.ProtoEq(a, b):
1960      compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg)
1961
1962  def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None):
1963    """Asserts that message is same as parsed expected_message_ascii.
1964
1965    Creates another prototype of message, reads the ascii message into it and
1966    then compares them using self._AssertProtoEqual().
1967
1968    Args:
1969      expected_message_maybe_ascii: proto message in original or ascii form.
1970      message: the message to validate.
1971      msg: Optional message to report on failure.
1972    """
1973    msg = msg if msg else ""
1974    if isinstance(expected_message_maybe_ascii, type(message)):
1975      expected_message = expected_message_maybe_ascii
1976      self._AssertProtoEquals(expected_message, message)
1977    elif isinstance(expected_message_maybe_ascii, str):
1978      expected_message = type(message)()
1979      text_format.Merge(
1980          expected_message_maybe_ascii,
1981          expected_message,
1982          descriptor_pool=descriptor_pool.Default())
1983      self._AssertProtoEquals(expected_message, message, msg=msg)
1984    else:
1985      assert False, ("Can't compare protos of type %s and %s. %s" %
1986                     (type(expected_message_maybe_ascii), type(message), msg))
1987
1988  def assertProtoEqualsVersion(
1989      self,
1990      expected,
1991      actual,
1992      producer=versions.GRAPH_DEF_VERSION,
1993      min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER,
1994      msg=None):
1995    expected = "versions { producer: %d min_consumer: %d };\n%s" % (
1996        producer, min_consumer, expected)
1997    self.assertProtoEquals(expected, actual, msg=msg)
1998
1999  def assertStartsWith(self, actual, expected_start, msg=None):
2000    """Assert that actual.startswith(expected_start) is True.
2001
2002    Args:
2003      actual: str
2004      expected_start: str
2005      msg: Optional message to report on failure.
2006    """
2007    if not actual.startswith(expected_start):
2008      fail_msg = "%r does not start with %r" % (actual, expected_start)
2009      fail_msg += " : %r" % (msg) if msg else ""
2010      self.fail(fail_msg)
2011
2012  def _eval_tensor(self, tensor):
2013    if tensor is None:
2014      return None
2015    elif callable(tensor):
2016      return self._eval_helper(tensor())
2017    else:
2018      try:
2019        if sparse_tensor.is_sparse(tensor):
2020          return sparse_tensor.SparseTensorValue(tensor.indices.numpy(),
2021                                                 tensor.values.numpy(),
2022                                                 tensor.dense_shape.numpy())
2023        elif ragged_tensor.is_ragged(tensor):
2024          return ragged_tensor_value.RaggedTensorValue(
2025              self._eval_tensor(tensor.values),
2026              self._eval_tensor(tensor.row_splits))
2027        elif isinstance(tensor, ops.IndexedSlices):
2028          return ops.IndexedSlicesValue(
2029              values=tensor.values.numpy(),
2030              indices=tensor.indices.numpy(),
2031              dense_shape=tensor.dense_shape.numpy())
2032        return tensor.numpy()
2033      except AttributeError as e:
2034        six.raise_from(ValueError("Unsupported type %s." % type(tensor)), e)
2035
2036  def _eval_helper(self, tensors):
2037    if tensors is None:
2038      return None
2039    return nest.map_structure(self._eval_tensor, tensors)
2040
2041  def evaluate(self, tensors):
2042    """Evaluates tensors and returns numpy values.
2043
2044    Args:
2045      tensors: A Tensor or a nested list/tuple of Tensors.
2046
2047    Returns:
2048      tensors numpy values.
2049    """
2050    if context.executing_eagerly():
2051      return self._eval_helper(tensors)
2052    else:
2053      sess = ops.get_default_session()
2054      if sess is None:
2055        with self.test_session() as sess:
2056          return sess.run(tensors)
2057      else:
2058        return sess.run(tensors)
2059
2060  # pylint: disable=g-doc-return-or-yield
2061  @contextlib.contextmanager
2062  def session(self, graph=None, config=None, use_gpu=False, force_gpu=False):
2063    """A context manager for a TensorFlow Session for use in executing tests.
2064
2065    Note that this will set this session and the graph as global defaults.
2066
2067    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
2068    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
2069    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
2070    possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to
2071    the CPU.
2072
2073    Example:
2074
2075    ``` python
2076    class MyOperatorTest(test_util.TensorFlowTestCase):
2077      def testMyOperator(self):
2078        with self.session(use_gpu=True):
2079          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
2080          result = MyOperator(valid_input).eval()
2081          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
2082          invalid_input = [-1.0, 2.0, 7.0]
2083          with self.assertRaisesOpError("negative input not supported"):
2084            MyOperator(invalid_input).eval()
2085    ```
2086
2087    Args:
2088      graph: Optional graph to use during the returned session.
2089      config: An optional config_pb2.ConfigProto to use to configure the
2090        session.
2091      use_gpu: If True, attempt to run as many ops as possible on GPU.
2092      force_gpu: If True, pin all ops to `/device:GPU:0`.
2093
2094    Yields:
2095      A Session object that should be used as a context manager to surround
2096      the graph building and execution code in a test case.
2097    """
2098    if context.executing_eagerly():
2099      yield EagerSessionWarner()
2100    else:
2101      with self._create_session(graph, config, force_gpu) as sess:
2102        with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu):
2103          yield sess
2104
2105  @contextlib.contextmanager
2106  def cached_session(self,
2107                     graph=None,
2108                     config=None,
2109                     use_gpu=False,
2110                     force_gpu=False):
2111    """Returns a TensorFlow Session for use in executing tests.
2112
2113    This method behaves differently than self.session(): for performance reasons
2114    `cached_session` will by default reuse the same session within the same
2115    test. The session returned by this function will only be closed at the end
2116    of the test (in the TearDown function).
2117
2118    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
2119    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
2120    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
2121    possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to
2122    the CPU.
2123
2124    Example:
2125    ```python
2126    class MyOperatorTest(test_util.TensorFlowTestCase):
2127      def testMyOperator(self):
2128        with self.cached_session(use_gpu=True) as sess:
2129          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
2130          result = MyOperator(valid_input).eval()
2131          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
2132          invalid_input = [-1.0, 2.0, 7.0]
2133          with self.assertRaisesOpError("negative input not supported"):
2134            MyOperator(invalid_input).eval()
2135    ```
2136
2137    Args:
2138      graph: Optional graph to use during the returned session.
2139      config: An optional config_pb2.ConfigProto to use to configure the
2140        session.
2141      use_gpu: If True, attempt to run as many ops as possible on GPU.
2142      force_gpu: If True, pin all ops to `/device:GPU:0`.
2143
2144    Yields:
2145      A Session object that should be used as a context manager to surround
2146      the graph building and execution code in a test case.
2147    """
2148    if context.executing_eagerly():
2149      yield FakeEagerSession(self)
2150    else:
2151      sess = self._get_cached_session(
2152          graph, config, force_gpu, crash_if_inconsistent_args=True)
2153      with self._constrain_devices_and_set_default(sess, use_gpu,
2154                                                   force_gpu) as cached:
2155        yield cached
2156
2157  @contextlib.contextmanager
2158  @deprecation.deprecated(None, "Use `self.session()` or "
2159                          "`self.cached_session()` instead.")
2160  def test_session(self,
2161                   graph=None,
2162                   config=None,
2163                   use_gpu=False,
2164                   force_gpu=False):
2165    """Use cached_session instead."""
2166    if self.id().endswith(".test_session"):
2167      self.skipTest(
2168          "Tests that have the name \"test_session\" are automatically skipped "
2169          "by TensorFlow test fixture, as the name is reserved for creating "
2170          "sessions within tests. Please rename your test if you have a test "
2171          "with this name.")
2172    if context.executing_eagerly():
2173      yield None
2174    else:
2175      if graph is None:
2176        sess = self._get_cached_session(
2177            graph, config, force_gpu, crash_if_inconsistent_args=False)
2178        with self._constrain_devices_and_set_default(sess, use_gpu,
2179                                                     force_gpu) as cached:
2180          yield cached
2181      else:
2182        with self.session(graph, config, use_gpu, force_gpu) as sess:
2183          yield sess
2184
2185  # pylint: enable=g-doc-return-or-yield
2186
2187  class _CheckedThread(object):
2188    """A wrapper class for Thread that asserts successful completion.
2189
2190    This class should be created using the TensorFlowTestCase.checkedThread()
2191    method.
2192    """
2193
2194    def __init__(self, testcase, target, args=None, kwargs=None):
2195      """Constructs a new instance of _CheckedThread.
2196
2197      Args:
2198        testcase: The TensorFlowTestCase for which this thread is being created.
2199        target: A callable object representing the code to be executed in the
2200          thread.
2201        args: A tuple of positional arguments that will be passed to target.
2202        kwargs: A dictionary of keyword arguments that will be passed to target.
2203      """
2204      self._testcase = testcase
2205      self._target = target
2206      self._args = () if args is None else args
2207      self._kwargs = {} if kwargs is None else kwargs
2208      self._thread = threading.Thread(target=self._protected_run)
2209      self._exception = None
2210
2211      self._is_thread_joined = False
2212
2213    def _protected_run(self):
2214      """Target for the wrapper thread. Sets self._exception on failure."""
2215      try:
2216        self._target(*self._args, **self._kwargs)
2217      except Exception as e:  # pylint: disable=broad-except
2218        self._exception = e
2219
2220    def start(self):
2221      """Starts the thread's activity.
2222
2223      This must be called at most once per _CheckedThread object. It arranges
2224      for the object's target to be invoked in a separate thread of control.
2225      """
2226      self._thread.start()
2227
2228    def join(self):
2229      """Blocks until the thread terminates.
2230
2231      Raises:
2232        self._testcase.failureException: If the thread terminates with due to
2233          an exception.
2234      """
2235      self._is_thread_joined = True
2236      self._thread.join()
2237      if self._exception is not None:
2238        self._testcase.fail("Error in checkedThread: %s" % str(self._exception))
2239
2240    def is_alive(self):
2241      """Returns whether the thread is alive.
2242
2243      This method returns True just before the run() method starts
2244      until just after the run() method terminates.
2245
2246      Returns:
2247        True if the thread is alive, otherwise False.
2248      """
2249      return self._thread.is_alive()
2250
2251    def check_termination(self):
2252      """Returns whether the checked thread was properly used and did terminate.
2253
2254      Every checked thread should be "join"ed after starting, and before the
2255      test tears down. If it is not joined, it is possible the thread will hang
2256      and cause flaky failures in tests.
2257
2258      Raises:
2259        self._testcase.failureException: If check_termination was called before
2260        thread was joined.
2261
2262        RuntimeError: If the thread is not terminated. This means thread was not
2263        joined with the main thread.
2264      """
2265      if self._is_thread_joined:
2266        if self.is_alive():
2267          raise RuntimeError(
2268              "Thread was not joined with main thread, and is still running "
2269              "when the test finished.")
2270      else:
2271        self._testcase.fail("A checked thread was not joined.")
2272
2273  def checkedThread(self, target, args=None, kwargs=None):
2274    """Returns a Thread wrapper that asserts 'target' completes successfully.
2275
2276    This method should be used to create all threads in test cases, as
2277    otherwise there is a risk that a thread will silently fail, and/or
2278    assertions made in the thread will not be respected.
2279
2280    Args:
2281      target: A callable object to be executed in the thread.
2282      args: The argument tuple for the target invocation. Defaults to ().
2283      kwargs: A dictionary of keyword arguments for the target invocation.
2284        Defaults to {}.
2285
2286    Returns:
2287      A wrapper for threading.Thread that supports start() and join() methods.
2288    """
2289    ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs)
2290    self._threads.append(ret)
2291    return ret
2292
2293  # pylint: enable=invalid-name
2294  @py_func_if_in_function
2295  def assertNear(self, f1, f2, err, msg=None):
2296    """Asserts that two floats are near each other.
2297
2298    Checks that |f1 - f2| < err and asserts a test failure
2299    if not.
2300
2301    Args:
2302      f1: A float value.
2303      f2: A float value.
2304      err: A float value.
2305      msg: An optional string message to append to the failure message.
2306    """
2307    # f1 == f2 is needed here as we might have: f1, f2 = inf, inf
2308    self.assertTrue(
2309        f1 == f2 or math.fabs(f1 - f2) <= err, "%f != %f +/- %f%s" %
2310        (f1, f2, err, " (%s)" % msg if msg is not None else ""))
2311
2312  @py_func_if_in_function
2313  def assertArrayNear(self, farray1, farray2, err, msg=None):
2314    """Asserts that two float arrays are near each other.
2315
2316    Checks that for all elements of farray1 and farray2
2317    |f1 - f2| < err.  Asserts a test failure if not.
2318
2319    Args:
2320      farray1: a list of float values.
2321      farray2: a list of float values.
2322      err: a float value.
2323      msg: Optional message to report on failure.
2324    """
2325    self.assertEqual(len(farray1), len(farray2), msg=msg)
2326    for f1, f2 in zip(farray1, farray2):
2327      self.assertNear(float(f1), float(f2), err, msg=msg)
2328
2329  def _NDArrayNear(self, ndarray1, ndarray2, err):
2330    return np.linalg.norm(ndarray1 - ndarray2) < err
2331
2332  @py_func_if_in_function
2333  def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None):
2334    """Asserts that two numpy arrays have near values.
2335
2336    Args:
2337      ndarray1: a numpy ndarray.
2338      ndarray2: a numpy ndarray.
2339      err: a float. The maximum absolute difference allowed.
2340      msg: Optional message to report on failure.
2341    """
2342    self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg)
2343
2344  def _GetNdArray(self, a):
2345    # If a is tensor-like then convert it to ndarray
2346    if tensor_util.is_tensor(a):
2347      if isinstance(a, ops._EagerTensorBase):
2348        a = a.numpy()
2349      else:
2350        a = self.evaluate(a)
2351    if not isinstance(a, np.ndarray):
2352      return np.array(a)
2353    return a
2354
2355  def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
2356    a = self._GetNdArray(a)
2357    b = self._GetNdArray(b)
2358    # When the array rank is small, print its contents. Numpy array printing is
2359    # implemented using inefficient recursion so prints can cause tests to
2360    # time out.
2361    if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
2362      shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
2363                            "%s.") % (a.shape, b.shape, b)
2364    else:
2365      shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
2366                                                                     b.shape)
2367    self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
2368
2369    msgs = [msg]
2370    if not np.allclose(a, b, rtol=rtol, atol=atol):
2371      # Adds more details to np.testing.assert_allclose.
2372      #
2373      # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
2374      # checks whether two arrays are element-wise equal within a
2375      # tolerance. The relative difference (rtol * abs(b)) and the
2376      # absolute difference atol are added together to compare against
2377      # the absolute difference between a and b.  Here, we want to
2378      # tell user which elements violate such conditions.
2379      cond = np.logical_or(
2380          np.abs(a - b) > atol + rtol * np.abs(b),
2381          np.isnan(a) != np.isnan(b))
2382      if a.ndim:
2383        x = a[np.where(cond)]
2384        y = b[np.where(cond)]
2385        msgs.append("not close where = {}".format(np.where(cond)))
2386      else:
2387        # np.where is broken for scalars
2388        x, y = a, b
2389      msgs.append("not close lhs = {}".format(x))
2390      msgs.append("not close rhs = {}".format(y))
2391      msgs.append("not close dif = {}".format(np.abs(x - y)))
2392      msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
2393      msgs.append("dtype = {}, shape = {}".format(a.dtype, a.shape))
2394      # TODO(xpan): There seems to be a bug:
2395      # tensorflow/compiler/tests:binary_ops_test pass with float32
2396      # nan even though the equal_nan is False by default internally.
2397      np.testing.assert_allclose(
2398          a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True)
2399
  def _assertAllCloseRecursive(self,
                               a,
                               b,
                               rtol=1e-6,
                               atol=1e-6,
                               path=None,
                               msg=None):
    """Recursively asserts that two nested structures have close values.

    Mappings (including namedtuples, via `_asdict()`) are compared key by key,
    lists and tuples are compared as whole arrays when possible and element by
    element otherwise, and anything else is compared as an array-like value.

    Args:
      a: The expected value or nested structure.
      b: The actual value or nested structure.
      rtol: relative tolerance.
      atol: absolute tolerance.
      path: List of keys/indices leading from the root to `a` and `b`; used
        only to build error messages. The same list is mutated and shared
        across recursive calls.
      msg: Optional message appended to error messages.

    Raises:
      ValueError: if exactly one of `a` and `b` is a mapping, or if two
        sequences that could not be compared as arrays differ in length.
      TypeError: re-raised, with location details appended, when the
        array-like comparison of leaf values raises it.
    """
    path = path or []
    # Rendered as "[k1][k2]..." for messages; empty at the root.
    path_str = (("[" + "][".join(str(p) for p in path) + "]") if path else "")
    msg = msg if msg else ""

    # Check if a and/or b are namedtuples.
    if hasattr(a, "_asdict"):
      a = a._asdict()
    if hasattr(b, "_asdict"):
      b = b._asdict()
    a_is_dict = isinstance(a, collections_abc.Mapping)
    if a_is_dict != isinstance(b, collections_abc.Mapping):
      raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" %
                       (path_str, path_str, msg))
    if a_is_dict:
      self.assertItemsEqual(
          a.keys(),
          b.keys(),
          msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" %
          (path_str, a.keys(), path_str, b.keys(), msg))
      for k in a:
        # Push the key for the recursive call and pop it afterwards so that
        # the shared path list always reflects the current position.
        path.append(k)
        self._assertAllCloseRecursive(
            a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg)
        del path[-1]
    elif isinstance(a, (list, tuple)):
      # Try to directly compare a, b as ndarrays; if not work, then traverse
      # through the sequence, which is more expensive.
      try:
        a_as_ndarray = self._GetNdArray(a)
        b_as_ndarray = self._GetNdArray(b)
        self._assertArrayLikeAllClose(
            a_as_ndarray,
            b_as_ndarray,
            rtol=rtol,
            atol=atol,
            msg="Mismatched value: a%s is different from b%s. %s" %
            (path_str, path_str, msg))
      except (ValueError, TypeError) as e:
        # Only ValueError/TypeError trigger the element-wise fallback; an
        # AssertionError from the comparison above propagates unchanged.
        if len(a) != len(b):
          raise ValueError(
              "Mismatched length: a%s has %d items, but b%s has %d items. %s" %
              (path_str, len(a), path_str, len(b), msg))
        for idx, (a_ele, b_ele) in enumerate(zip(a, b)):
          path.append(str(idx))
          self._assertAllCloseRecursive(
              a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg)
          del path[-1]
    # a and b are ndarray like objects
    else:
      try:
        self._assertArrayLikeAllClose(
            a,
            b,
            rtol=rtol,
            atol=atol,
            msg=("Mismatched value: a%s is different from b%s. %s" %
                 (path_str, path_str, msg)))
      except TypeError as e:
        # Append the location and the operand types to the original error
        # text, then re-raise the same exception object.
        msg = ("Error: a%s has %s, but b%s has %s. %s" %
               (path_str, type(a), path_str, type(b), msg))
        e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
        raise
2469
2470  @py_func_if_in_function
2471  def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
2472    """Asserts that two structures of numpy arrays or Tensors, have near values.
2473
2474    `a` and `b` can be arbitrarily nested structures. A layer of a nested
2475    structure can be a `dict`, `namedtuple`, `tuple` or `list`.
2476
2477    Args:
2478      a: The expected numpy `ndarray`, or anything that can be converted into a
2479        numpy `ndarray` (including Tensor), or any arbitrarily nested of
2480        structure of these.
2481      b: The actual numpy `ndarray`, or anything that can be converted into a
2482        numpy `ndarray` (including Tensor), or any arbitrarily nested of
2483        structure of these.
2484      rtol: relative tolerance.
2485      atol: absolute tolerance.
2486      msg: Optional message to report on failure.
2487
2488    Raises:
2489      ValueError: if only one of `a[p]` and `b[p]` is a dict or
2490          `a[p]` and `b[p]` have different length, where `[p]` denotes a path
2491          to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and
2492          `[p] = [1]['d']`, then `a[p] = (6, 7)`.
2493    """
2494    if ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b):
2495      return self._assertRaggedClose(a, b, rtol, atol, msg)
2496    self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg)
2497
2498  @py_func_if_in_function
2499  def assertAllCloseAccordingToType(self,
2500                                    a,
2501                                    b,
2502                                    rtol=1e-6,
2503                                    atol=1e-6,
2504                                    float_rtol=1e-6,
2505                                    float_atol=1e-6,
2506                                    half_rtol=1e-3,
2507                                    half_atol=1e-3,
2508                                    bfloat16_rtol=1e-2,
2509                                    bfloat16_atol=1e-2,
2510                                    msg=None):
2511    """Like assertAllClose, but also suitable for comparing fp16 arrays.
2512
2513    In particular, the tolerance is reduced to 1e-3 if at least
2514    one of the arguments is of type float16.
2515
2516    Args:
2517      a: the expected numpy ndarray or anything can be converted to one.
2518      b: the actual numpy ndarray or anything can be converted to one.
2519      rtol: relative tolerance.
2520      atol: absolute tolerance.
2521      float_rtol: relative tolerance for float32.
2522      float_atol: absolute tolerance for float32.
2523      half_rtol: relative tolerance for float16.
2524      half_atol: absolute tolerance for float16.
2525      bfloat16_rtol: relative tolerance for bfloat16.
2526      bfloat16_atol: absolute tolerance for bfloat16.
2527      msg: Optional message to report on failure.
2528    """
2529    a = self._GetNdArray(a)
2530    b = self._GetNdArray(b)
2531    # types with lower tol are put later to overwrite previous ones.
2532    if (a.dtype == np.float32 or b.dtype == np.float32 or
2533        a.dtype == np.complex64 or b.dtype == np.complex64):
2534      rtol = max(rtol, float_rtol)
2535      atol = max(atol, float_atol)
2536    if a.dtype == np.float16 or b.dtype == np.float16:
2537      rtol = max(rtol, half_rtol)
2538      atol = max(atol, half_atol)
2539    if (a.dtype == dtypes.bfloat16.as_numpy_dtype or
2540        b.dtype == dtypes.bfloat16.as_numpy_dtype):
2541      rtol = max(rtol, bfloat16_rtol)
2542      atol = max(atol, bfloat16_atol)
2543
2544    self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
2545
2546  @py_func_if_in_function
2547  def assertNotAllClose(self, a, b, **kwargs):
2548    """Assert that two numpy arrays, or Tensors, do not have near values.
2549
2550    Args:
2551      a: the first value to compare.
2552      b: the second value to compare.
2553      **kwargs: additional keyword arguments to be passed to the underlying
2554        `assertAllClose` call.
2555
2556    Raises:
2557      AssertionError: If `a` and `b` are unexpectedly close at all elements.
2558    """
2559    try:
2560      self.assertAllClose(a, b, **kwargs)
2561    except AssertionError:
2562      return
2563    raise AssertionError("The two values are close at all elements")
2564
2565  @py_func_if_in_function
2566  def assertAllEqual(self, a, b, msg=None):
2567    """Asserts that two numpy arrays or Tensors have the same values.
2568
2569    Args:
2570      a: the expected numpy ndarray or anything can be converted to one.
2571      b: the actual numpy ndarray or anything can be converted to one.
2572      msg: Optional message to report on failure.
2573    """
2574    if (ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b)):
2575      return self._assertRaggedEqual(a, b, msg)
2576    msg = msg if msg else ""
2577    a = self._GetNdArray(a)
2578    b = self._GetNdArray(b)
2579    # Arbitrary bounds so that we don't print giant tensors.
2580    if (b.ndim <= 3 or b.size < 500):
2581      self.assertEqual(
2582          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
2583          " Contents: %s. \n%s." % (a.shape, b.shape, b, msg))
2584    else:
2585      self.assertEqual(
2586          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
2587          " %s" % (a.shape, b.shape, msg))
2588
2589    same = (a == b)
2590
2591    if (a.dtype in [
2592        np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype
2593    ]):
2594      same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
2595    msgs = [msg]
2596    if not np.all(same):
2597      # Adds more details to np.testing.assert_array_equal.
2598      diff = np.logical_not(same)
2599      if a.ndim:
2600        x = a[np.where(diff)]
2601        y = b[np.where(diff)]
2602        msgs.append("not equal where = {}".format(np.where(diff)))
2603      else:
2604        # np.where is broken for scalars
2605        x, y = a, b
2606      msgs.append("not equal lhs = {}".format(x))
2607      msgs.append("not equal rhs = {}".format(y))
2608      # With Python 3, we need to make sure the dtype matches between a and b.
2609      b = b.astype(a.dtype)
2610      np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))
2611
2612  @py_func_if_in_function
2613  def assertNotAllEqual(self, a, b, msg=None):
2614    """Asserts that two numpy arrays or Tensors do not have the same values.
2615
2616    Args:
2617      a: the expected numpy ndarray or anything can be converted to one.
2618      b: the actual numpy ndarray or anything can be converted to one.
2619      msg: Optional message to report on failure.
2620    """
2621    try:
2622      self.assertAllEqual(a, b, msg)
2623    except AssertionError:
2624      return
2625    raise AssertionError("The two values are equal at all elements")
2626
2627  @py_func_if_in_function
2628  def assertAllGreater(self, a, comparison_target):
2629    """Assert element values are all greater than a target value.
2630
2631    Args:
2632      a: The numpy `ndarray`, or anything that can be converted into a numpy
2633        `ndarray` (including Tensor).
2634      comparison_target: The target value of comparison.
2635    """
2636    a = self._GetNdArray(a)
2637    self.assertGreater(np.min(a), comparison_target)
2638
2639  @py_func_if_in_function
2640  def assertAllLess(self, a, comparison_target):
2641    """Assert element values are all less than a target value.
2642
2643    Args:
2644      a: The numpy `ndarray`, or anything that can be converted into a numpy
2645        `ndarray` (including Tensor).
2646      comparison_target: The target value of comparison.
2647    """
2648    a = self._GetNdArray(a)
2649    self.assertLess(np.max(a), comparison_target)
2650
2651  @py_func_if_in_function
2652  def assertAllGreaterEqual(self, a, comparison_target):
2653    """Assert element values are all greater than or equal to a target value.
2654
2655    Args:
2656      a: The numpy `ndarray`, or anything that can be converted into a numpy
2657        `ndarray` (including Tensor).
2658      comparison_target: The target value of comparison.
2659    """
2660    a = self._GetNdArray(a)
2661    self.assertGreaterEqual(np.min(a), comparison_target)
2662
2663  @py_func_if_in_function
2664  def assertAllLessEqual(self, a, comparison_target):
2665    """Assert element values are all less than or equal to a target value.
2666
2667    Args:
2668      a: The numpy `ndarray`, or anything that can be converted into a numpy
2669        `ndarray` (including Tensor).
2670      comparison_target: The target value of comparison.
2671    """
2672    a = self._GetNdArray(a)
2673    self.assertLessEqual(np.max(a), comparison_target)
2674
2675  def _format_subscripts(self, subscripts, value, limit=10, indent=2):
2676    """Generate a summary of ndarray subscripts as a list of str.
2677
2678    If limit == N, this method will print up to the first N subscripts on
2679    separate
2680    lines. A line of ellipses (...) will be appended at the end if the number of
2681    subscripts exceeds N.
2682
2683    Args:
2684      subscripts: The tensor (np.ndarray) subscripts, of the same format as
2685        np.where()'s return value, i.e., a tuple of arrays with each array
2686        corresponding to a dimension. E.g., (array([1, 1]), array([0, 1])).
2687      value: (np.ndarray) value of the tensor.
2688      limit: (int) The maximum number of indices to print.
2689      indent: (int) Number of characters to indent at the beginning of each
2690        line.
2691
2692    Returns:
2693      (list of str) the multi-line representation of the subscripts and values,
2694        potentially with omission at the end.
2695    """
2696    lines = []
2697    subscripts = np.transpose(subscripts)
2698    prefix = " " * indent
2699    for subscript in itertools.islice(subscripts, limit):
2700      lines.append(prefix + str(subscript) + " : " +
2701                   str(value[tuple(subscript)]))
2702    if len(subscripts) > limit:
2703      lines.append(prefix + "...")
2704    return lines
2705
2706  @py_func_if_in_function
2707  def assertAllInRange(self,
2708                       target,
2709                       lower_bound,
2710                       upper_bound,
2711                       open_lower_bound=False,
2712                       open_upper_bound=False):
2713    """Assert that elements in a Tensor are all in a given range.
2714
2715    Args:
2716      target: The numpy `ndarray`, or anything that can be converted into a
2717        numpy `ndarray` (including Tensor).
2718      lower_bound: lower bound of the range
2719      upper_bound: upper bound of the range
2720      open_lower_bound: (`bool`) whether the lower bound is open (i.e., > rather
2721        than the default >=)
2722      open_upper_bound: (`bool`) whether the upper bound is open (i.e., < rather
2723        than the default <=)
2724
2725    Raises:
2726      AssertionError:
2727        if the value tensor does not have an ordered numeric type (float* or
2728          int*), or
2729        if there are nan values, or
2730        if any of the elements do not fall in the specified range.
2731    """
2732    target = self._GetNdArray(target)
2733    if not (np.issubdtype(target.dtype, np.floating) or
2734            np.issubdtype(target.dtype, np.integer)):
2735      raise AssertionError(
2736          "The value of %s does not have an ordered numeric type, instead it "
2737          "has type: %s" % (target, target.dtype))
2738
2739    nan_subscripts = np.where(np.isnan(target))
2740    if np.size(nan_subscripts):
2741      raise AssertionError(
2742          "%d of the %d element(s) are NaN. "
2743          "Subscripts(s) and value(s) of the NaN element(s):\n" %
2744          (len(nan_subscripts[0]), np.size(target)) +
2745          "\n".join(self._format_subscripts(nan_subscripts, target)))
2746
2747    range_str = (("(" if open_lower_bound else "[") + str(lower_bound) + ", " +
2748                 str(upper_bound) + (")" if open_upper_bound else "]"))
2749
2750    violations = (
2751        np.less_equal(target, lower_bound) if open_lower_bound else np.less(
2752            target, lower_bound))
2753    violations = np.logical_or(
2754        violations,
2755        np.greater_equal(target, upper_bound)
2756        if open_upper_bound else np.greater(target, upper_bound))
2757    violation_subscripts = np.where(violations)
2758    if np.size(violation_subscripts):
2759      raise AssertionError(
2760          "%d of the %d element(s) are outside the range %s. " %
2761          (len(violation_subscripts[0]), np.size(target), range_str) +
2762          "Subscript(s) and value(s) of the offending elements:\n" +
2763          "\n".join(self._format_subscripts(violation_subscripts, target)))
2764
2765  @py_func_if_in_function
2766  def assertAllInSet(self, target, expected_set):
2767    """Assert that elements of a Tensor are all in a given closed set.
2768
2769    Args:
2770      target: The numpy `ndarray`, or anything that can be converted into a
2771        numpy `ndarray` (including Tensor).
2772      expected_set: (`list`, `tuple` or `set`) The closed set that the elements
2773        of the value of `target` are expected to fall into.
2774
2775    Raises:
2776      AssertionError:
2777        if any of the elements do not fall into `expected_set`.
2778    """
2779    target = self._GetNdArray(target)
2780
2781    # Elements in target that are not in expected_set.
2782    diff = np.setdiff1d(target.flatten(), list(expected_set))
2783    if np.size(diff):
2784      raise AssertionError("%d unique element(s) are not in the set %s: %s" %
2785                           (np.size(diff), expected_set, diff))
2786
2787  @py_func_if_in_function
2788  def assertDTypeEqual(self, target, expected_dtype):
2789    """Assert ndarray data type is equal to expected.
2790
2791    Args:
2792      target: The numpy `ndarray`, or anything that can be converted into a
2793        numpy `ndarray` (including Tensor).
2794      expected_dtype: Expected data type.
2795    """
2796    target = self._GetNdArray(target)
2797    if not isinstance(target, list):
2798      arrays = [target]
2799    for arr in arrays:
2800      self.assertEqual(arr.dtype, expected_dtype)
2801
2802  # pylint: disable=g-doc-return-or-yield
2803  @contextlib.contextmanager
2804  def assertRaisesWithPredicateMatch(self, exception_type,
2805                                     expected_err_re_or_predicate):
2806    """Returns a context manager to enclose code expected to raise an exception.
2807
2808    If the exception is an OpError, the op stack is also included in the message
2809    predicate search.
2810
2811    Args:
2812      exception_type: The expected type of exception that should be raised.
2813      expected_err_re_or_predicate: If this is callable, it should be a function
2814        of one argument that inspects the passed-in exception and returns True
2815        (success) or False (please fail the test). Otherwise, the error message
2816        is expected to match this regular expression partially.
2817
2818    Returns:
2819      A context manager to surround code that is expected to raise an
2820      exception.
2821    """
2822    if callable(expected_err_re_or_predicate):
2823      predicate = expected_err_re_or_predicate
2824    else:
2825
2826      def predicate(e):
2827        err_str = e.message if isinstance(e, errors.OpError) else str(e)
2828        op = e.op if isinstance(e, errors.OpError) else None
2829        while op is not None:
2830          err_str += "\nCaused by: " + op.name
2831          op = op._original_op  # pylint: disable=protected-access
2832        logging.info("Searching within error strings: '%s' within '%s'",
2833                     expected_err_re_or_predicate, err_str)
2834        return re.search(expected_err_re_or_predicate, err_str)
2835
2836    try:
2837      yield
2838      self.fail(exception_type.__name__ + " not raised")
2839    except Exception as e:  # pylint: disable=broad-except
2840      if not isinstance(e, exception_type) or not predicate(e):
2841        raise AssertionError("Exception of type %s: %s" %
2842                             (str(type(e)), str(e)))
2843
2844  # pylint: enable=g-doc-return-or-yield
2845
2846  def assertRaisesOpError(self, expected_err_re_or_predicate):
2847    return self.assertRaisesWithPredicateMatch(errors.OpError,
2848                                               expected_err_re_or_predicate)
2849
2850  def assertShapeEqual(self, np_array, tf_tensor, msg=None):
2851    """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape.
2852
2853    Args:
2854      np_array: A Numpy ndarray or Numpy scalar.
2855      tf_tensor: A Tensor.
2856      msg: Optional message to report on failure.
2857
2858    Raises:
2859      TypeError: If the arguments have the wrong type.
2860    """
2861    if not isinstance(np_array, (np.ndarray, np.generic)):
2862      raise TypeError("np_array must be a Numpy ndarray or Numpy scalar")
2863    if not isinstance(tf_tensor, ops.Tensor):
2864      raise TypeError("tf_tensor must be a Tensor")
2865    self.assertAllEqual(
2866        np_array.shape, tf_tensor.get_shape().as_list(), msg=msg)
2867
2868  def assertDeviceEqual(self, device1, device2, msg=None):
2869    """Asserts that the two given devices are the same.
2870
2871    Args:
2872      device1: A string device name or TensorFlow `DeviceSpec` object.
2873      device2: A string device name or TensorFlow `DeviceSpec` object.
2874      msg: Optional message to report on failure.
2875    """
2876    device1 = pydev.canonical_name(device1)
2877    device2 = pydev.canonical_name(device2)
2878    self.assertEqual(
2879        device1, device2,
2880        "Devices %s and %s are not equal. %s" % (device1, device2, msg))
2881
2882  def _GetPyList(self, a):
2883    """Converts `a` to a nested python list."""
2884    if isinstance(a, ragged_tensor.RaggedTensor):
2885      return self.evaluate(a).to_list()
2886    elif isinstance(a, ops.Tensor):
2887      a = self.evaluate(a)
2888      return a.tolist() if isinstance(a, np.ndarray) else a
2889    elif isinstance(a, np.ndarray):
2890      return a.tolist()
2891    elif isinstance(a, ragged_tensor_value.RaggedTensorValue):
2892      return a.to_list()
2893    else:
2894      return np.array(a).tolist()
2895
2896  def _assertRaggedEqual(self, a, b, msg):
2897    """Asserts that two ragged tensors are equal."""
2898    a_list = self._GetPyList(a)
2899    b_list = self._GetPyList(b)
2900    self.assertEqual(a_list, b_list, msg)
2901
2902    if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
2903      a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
2904      b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
2905      self.assertEqual(a_ragged_rank, b_ragged_rank, msg)
2906
2907  def _assertRaggedClose(self, a, b, rtol, atol, msg=None):
2908    a_list = self._GetPyList(a)
2909    b_list = self._GetPyList(b)
2910    self._assertListCloseRecursive(a_list, b_list, rtol, atol, msg)
2911
2912    if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
2913      a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
2914      b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
2915      self.assertEqual(a_ragged_rank, b_ragged_rank, msg)
2916
2917  def _assertListCloseRecursive(self, a, b, rtol, atol, msg, path="value"):
2918    self.assertEqual(type(a), type(b))
2919    if isinstance(a, (list, tuple)):
2920      self.assertLen(a, len(b), "Length differs for %s" % path)
2921      for i in range(len(a)):
2922        self._assertListCloseRecursive(a[i], b[i], rtol, atol, msg,
2923                                       "%s[%s]" % (path, i))
2924    else:
2925      self._assertAllCloseRecursive(a, b, rtol, atol, path, msg)
2926
  # Python 3 compatibility aliases: when not running under Python 2, the
  # legacy unittest method names below are rebound to the methods that
  # `googletest.TestCase` provides under their current names.
  if not six.PY2:
    # pylint: disable=invalid-name

    # `assertRaisesRegexp` is the deprecated spelling; aliasing it to
    # `assertRaisesRegex` silences the deprecation warning.
    assertRaisesRegexp = googletest.TestCase.assertRaisesRegex

    # assertItemsEqual is assertCountEqual as of 3.2.
    assertItemsEqual = googletest.TestCase.assertCountEqual

    # pylint: enable=invalid-name
2938
2939  @contextlib.contextmanager
2940  def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
2941    """Set the session and its graph to global default and constrain devices."""
2942    if context.executing_eagerly():
2943      yield None
2944    else:
2945      with sess.graph.as_default(), sess.as_default():
2946        if force_gpu:
2947          # Use the name of an actual device if one is detected, or
2948          # '/device:GPU:0' otherwise
2949          gpu_name = gpu_device_name()
2950          if not gpu_name:
2951            gpu_name = "/device:GPU:0"
2952          with sess.graph.device(gpu_name):
2953            yield sess
2954        elif use_gpu:
2955          yield sess
2956        else:
2957          with sess.graph.device("/device:CPU:0"):
2958            yield sess
2959
2960  def _create_session(self, graph, config, force_gpu):
2961    """See session() for details."""
2962
2963    def prepare_config(config):
2964      """Returns a config for sessions.
2965
2966      Args:
2967        config: An optional config_pb2.ConfigProto to use to configure the
2968          session.
2969
2970      Returns:
2971        A config_pb2.ConfigProto object.
2972      """
2973      # TODO(b/114333779): Enforce allow_soft_placement=False when
2974      # use_gpu=False. Currently many tests rely on the fact that any device
2975      # will be used even when a specific device is supposed to be used.
2976      allow_soft_placement = not force_gpu
2977      if config is None:
2978        config = context.context().config
2979        config.allow_soft_placement = allow_soft_placement
2980      elif not allow_soft_placement and config.allow_soft_placement:
2981        config_copy = context.context().config
2982        config = config_copy
2983        config.allow_soft_placement = False
2984      # Don't perform optimizations for tests so we don't inadvertently run
2985      # gpu ops on cpu
2986      config.graph_options.optimizer_options.opt_level = -1
2987      # Disable Grappler constant folding since some tests & benchmarks
2988      # use constant input and become meaningless after constant folding.
2989      # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE
2990      # GRAPPLER TEAM.
2991      config.graph_options.rewrite_options.constant_folding = (
2992          rewriter_config_pb2.RewriterConfig.OFF)
2993      config.graph_options.rewrite_options.pin_to_host_optimization = (
2994          rewriter_config_pb2.RewriterConfig.OFF)
2995      return config
2996
2997    return ErrorLoggingSession(graph=graph, config=prepare_config(config))
2998
2999  def _get_cached_session(self,
3000                          graph=None,
3001                          config=None,
3002                          force_gpu=False,
3003                          crash_if_inconsistent_args=True):
3004    """See cached_session() for documentation."""
3005    if self._cached_session is None:
3006      sess = self._create_session(
3007          graph=graph, config=config, force_gpu=force_gpu)
3008      self._cached_session = sess
3009      self._cached_graph = graph
3010      self._cached_config = config
3011      self._cached_force_gpu = force_gpu
3012      return sess
3013    else:
3014      if crash_if_inconsistent_args and self._cached_graph is not graph:
3015        raise ValueError("The graph used to get the cached session is "
3016                         "different than the one that was used to create the "
3017                         "session. Maybe create a new session with "
3018                         "self.session()")
3019      if crash_if_inconsistent_args and self._cached_config is not config:
3020        raise ValueError("The config used to get the cached session is "
3021                         "different than the one that was used to create the "
3022                         "session. Maybe create a new session with "
3023                         "self.session()")
3024      if crash_if_inconsistent_args and (self._cached_force_gpu is
3025                                         not force_gpu):
3026        raise ValueError(
3027            "The force_gpu value used to get the cached session is "
3028            "different than the one that was used to create the "
3029            "session. Maybe create a new session with "
3030            "self.session()")
3031      return self._cached_session
3032
3033
3034@tf_export("test.create_local_cluster")
def create_local_cluster(num_workers,
                         num_ps,
                         protocol="grpc",
                         worker_config=None,
                         ps_config=None):
  """Create and start local servers and return the associated `Server` objects.

  "PS" stands for "parameter server": a task responsible for storing and
  updating the model's parameters. Other tasks send updates to these parameters
  as they work on optimizing the parameters. This particular division of labor
  between tasks is not required, but is common for distributed training.

  Read more at https://www.tensorflow.org/guide/extend/architecture

  ![components](https://www.tensorflow.org/images/diag1.svg "components")


  Figure illustrates the interaction of these components.
  "/job:worker/task:0" and "/job:ps/task:0" are both tasks with worker services.


  Example:
  ```python
  workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2)

  worker_sessions = [tf.compat.v1.Session(w.target) for w in workers]

  with tf.device("/job:ps/task:0"):
    ...
  with tf.device("/job:ps/task:1"):
    ...
  with tf.device("/job:worker/task:0"):
    ...
  with tf.device("/job:worker/task:1"):
    ...

  worker_sessions[0].run(...)
  ```

  Args:
    num_workers: Number of worker servers to start.
    num_ps: Number of PS servers to start.
    protocol: Communication protocol. Allowed values are documented in the
      documentation of `tf.distribute.Server`.
    worker_config: (optional) `tf.ConfigProto` to initialize workers. Can be
      used to instantiate multiple devices etc.
    ps_config: (optional) `tf.ConfigProto` to initialize PS servers.

  Returns:
    A tuple `(worker_servers, ps_servers)`.  `worker_servers` is a list
    of `num_workers` objects of type `tf.distribute.Server` (all running
    locally);
    and `ps_servers` is a list of `num_ps` objects of similar type.

  Raises:
    ImportError: if portpicker module was not found at load time
  """
  import portpicker  # pylint: disable=g-import-not-at-top

  def _pick_addresses(count):
    """Returns `count` "localhost:<port>" addresses on unused ports."""
    return [
        "localhost:%s" % portpicker.pick_unused_port() for _ in range(count)
    ]

  # Worker ports are picked before PS ports.
  cluster_spec = server_lib.ClusterSpec({
      "worker": _pick_addresses(num_workers),
      "ps": _pick_addresses(num_ps)
  })

  def _start_servers(job_name, count, config):
    """Starts one server per task index of `job_name`."""
    return [
        server_lib.Server(
            cluster_spec,
            job_name=job_name,
            protocol=protocol,
            task_index=task_index,
            config=config,
            start=True) for task_index in range(count)
    ]

  workers = _start_servers("worker", num_workers, worker_config)
  ps_servers = _start_servers("ps", num_ps, ps_config)
  return workers, ps_servers
3121
3122
def get_node_def_from_graph(node_name, graph_def):
  """Returns the `NodeDef` instance for given node name in the graph def.

  Only the entries of `graph_def.node` are searched, in order, and the first
  one whose name matches is returned.

  Args:
    node_name: Name of the NodeDef to search for.
    graph_def: An instance of `GraphDef` proto.

  Returns:
    the `NodeDef` instance whose name field matches the given node_name or None.
  """
  matching_nodes = (
      node_def for node_def in graph_def.node if node_def.name == node_name)
  return next(matching_nodes, None)
3139
3140
def set_producer_version(graph, producer_version):
  """Sets graph.graph_def_versions.producer to `producer_version`.

  Args:
    graph: The graph whose producer version should be changed.
    producer_version: The producer version to set.

  Raises:
    AssertionError: if the graph's producer version does not equal
      `producer_version` after the import.
  """
  # The C API doesn't expose altering GraphDefVersions. We can indirectly set
  # it via import_graph_def though.
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.producer = producer_version
  with graph.as_default():
    importer.import_graph_def(graph_def)
  # Compare the values explicitly: `assert x, y` would only test the
  # truthiness of `x` and use `y` as the failure message.
  actual_version = graph.graph_def_versions.producer
  assert actual_version == producer_version, (
      "Expected producer version %s, got %s" %
      (producer_version, actual_version))
3150