• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16# pylint: disable=invalid-name
17"""Test utils for tensorflow."""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
import contextlib
import functools
import gc
import math
import random
import re
import tempfile
import threading
29
30import numpy as np
31import six
32
33_portpicker_import_error = None
34try:
35  import portpicker  # pylint: disable=g-import-not-at-top
36except ImportError as _error:
37  _portpicker_import_error = _error
38  portpicker = None
39
40# pylint: disable=g-import-not-at-top
41from google.protobuf import descriptor_pool
42from google.protobuf import text_format
43
44from tensorflow.core.framework import graph_pb2
45from tensorflow.core.protobuf import config_pb2
46from tensorflow.core.protobuf import rewriter_config_pb2
47from tensorflow.python import pywrap_tensorflow
48from tensorflow.python.client import device_lib
49from tensorflow.python.client import session
50from tensorflow.python.eager import backprop
51from tensorflow.python.eager import context
52from tensorflow.python.eager import tape  # pylint: disable=unused-import
53from tensorflow.python.framework import device as pydev
54from tensorflow.python.framework import dtypes
55from tensorflow.python.framework import errors
56from tensorflow.python.framework import importer
57from tensorflow.python.framework import ops
58from tensorflow.python.framework import random_seed
59from tensorflow.python.framework import versions
60from tensorflow.python.ops import array_ops
61from tensorflow.python.ops import resource_variable_ops
62from tensorflow.python.ops import variables
63from tensorflow.python.platform import googletest
64from tensorflow.python.platform import tf_logging as logging
65from tensorflow.python.training import server_lib
66from tensorflow.python.util import compat
67from tensorflow.python.util import nest
68from tensorflow.python.util.protobuf import compare
69from tensorflow.python.util.tf_export import tf_export
70
71
@tf_export("test.gpu_device_name")
def gpu_device_name():
  """Returns the name of a GPU device if available or the empty string.

  Scans the local devices in order and returns the name (as a native `str`)
  of the first one whose type is "GPU" or "SYCL"; returns "" if none exists.
  """
  accelerator_types = ("GPU", "SYCL")
  for local_device in device_lib.list_local_devices():
    if local_device.device_type in accelerator_types:
      return compat.as_str(local_device.name)
  return ""
79
80
def assert_ops_in_graph(expected_ops, graph):
  """Assert all expected operations are found.

  Nodes of `graph` whose names are not keys of `expected_ops` are ignored.

  Args:
    expected_ops: `dict<string, string>` of op name to op type.
    graph: Graph to check.
  Returns:
    `dict<string, node>` of node name to node.

  Raises:
    ValueError: If an expected node has a different op type, or if any
      expected op is not present in the graph.
  """
  matched = {}
  graph_def = graph.as_graph_def()
  for node in graph_def.node:
    if node.name not in expected_ops:
      continue
    wanted_type = expected_ops[node.name]
    if wanted_type != node.op:
      raise ValueError("Expected op for node %s is different. %s vs %s" %
                       (node.name, wanted_type, node.op))
    matched[node.name] = node
  if set(expected_ops.keys()) != set(matched.keys()):
    raise ValueError("Not all expected ops are present. Expected %s, found %s" %
                     (expected_ops.keys(), matched.keys()))
  return matched
105
106
@tf_export("test.assert_equal_graph_def")
def assert_equal_graph_def(actual, expected, checkpoint_v2=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs.  Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to ignore randomized attribute
        values that appear in V2 checkpoints. When True, the matching
        attributes are deleted from both `actual` and `expected` in place
        before comparing.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  type_checks = (
      (actual, "Expected tf.GraphDef for actual, got %s"),
      (expected, "Expected tf.GraphDef for expected, got %s"),
  )
  for candidate, error_template in type_checks:
    if not isinstance(candidate, graph_pb2.GraphDef):
      raise TypeError(error_template % type(candidate).__name__)

  if checkpoint_v2:
    for graph_def in (actual, expected):
      _strip_checkpoint_v2_randomized(graph_def)

  serialized_actual = actual.SerializeToString()
  serialized_expected = expected.SerializeToString()
  diff = pywrap_tensorflow.EqualGraphDefWrapper(serialized_actual,
                                                serialized_expected)
  if diff:
    raise AssertionError(compat.as_str(diff))
140
141
def assert_meta_graph_protos_equal(tester, a, b):
  """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

  Collections are compared key by key. When a proto type is registered for a
  collection key (`ops.get_collection_proto_type`), its serialized entries are
  parsed and compared with `assertProtoEquals`; otherwise the raw
  `CollectionDef`s are compared with `assertEqual`. The graph_defs are compared
  with `assert_equal_graph_def(..., checkpoint_v2=True)` plus an explicit check
  of their versions, and the remaining fields with `assertProtoEquals`.

  Note: `a` and `b` are modified in place. Their `collection_def` and
  `graph_def` fields are cleared once compared.

  Args:
    tester: A test case object providing `assertEqual` and `assertProtoEquals`
      (e.g. a `TensorFlowTestCase`).
    a: A `MetaGraphDef`.
    b: A `MetaGraphDef`.
  """
  # Carefully check the collection_defs
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  collection_keys = a.collection_def.keys()
  for k in collection_keys:
    a_value = a.collection_def[k]
    b_value = b.collection_def[k]
    proto_type = ops.get_collection_proto_type(k)
    if proto_type:
      a_proto = proto_type()
      b_proto = proto_type()
      # Number of entries in the collections is the same
      tester.assertEqual(
          len(a_value.bytes_list.value), len(b_value.bytes_list.value))
      for (a_value_item, b_value_item) in zip(a_value.bytes_list.value,
                                              b_value.bytes_list.value):
        a_proto.ParseFromString(a_value_item)
        b_proto.ParseFromString(b_value_item)
        tester.assertProtoEquals(a_proto, b_proto)
    else:
      # Use assertEqual: the assertEquals alias is deprecated and was removed
      # from unittest in Python 3.12.
      tester.assertEqual(a_value, b_value)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("collection_def")
  b.ClearField("collection_def")

  # Check the graph_defs.
  assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
  # Check graph_def versions (ignored by assert_equal_graph_def).
  tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("graph_def")
  b.ClearField("graph_def")

  tester.assertProtoEquals(a, b)
179
180
# Matches attributes named via _SHARDED_SUFFIX in
# tensorflow/python/training/saver.py
_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part"


def _strip_checkpoint_v2_randomized(graph_def):
  """Deletes randomized V2-checkpoint attributes from `graph_def` in place.

  An attribute is removed when its tensor value holds exactly one non-empty
  string that matches `_SHARDED_SAVE_OP_PATTERN` at the start.

  Args:
    graph_def: A `GraphDef` whose nodes' attrs are edited in place.
  """
  # `string_val` entries are bytes on Python 3 (and may be text elsewhere).
  # Calling str() on bytes would produce "b'...'" and never match, so match
  # bytes against a bytes pattern and text against a text pattern.
  text_pattern = re.compile(_SHARDED_SAVE_OP_PATTERN)
  bytes_pattern = re.compile(_SHARDED_SAVE_OP_PATTERN.encode("ascii"))
  for node in graph_def.node:
    delete_keys = []
    for attr_key in node.attr:
      attr_tensor_value = node.attr[attr_key].tensor
      if attr_tensor_value and len(attr_tensor_value.string_val) == 1:
        attr_tensor_string_value = attr_tensor_value.string_val[0]
        if not attr_tensor_string_value:
          continue
        pattern = (bytes_pattern
                   if isinstance(attr_tensor_string_value, bytes)
                   else text_pattern)
        if pattern.match(attr_tensor_string_value):
          delete_keys.append(attr_key)
    # Delete after iterating so node.attr is not mutated mid-iteration.
    for attr_key in delete_keys:
      del node.attr[attr_key]
198
199
def IsGoogleCudaEnabled():
  """Returns `pywrap_tensorflow.IsGoogleCudaEnabled()`.

  Presumably reports whether this build has CUDA support; confirm against the
  native implementation.
  """
  cuda_enabled = pywrap_tensorflow.IsGoogleCudaEnabled()
  return cuda_enabled
202
203
def CudaSupportsHalfMatMulAndConv():
  """Returns `pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()`.

  Presumably reports whether half-precision MatMul/Conv are supported by the
  available CUDA setup; confirm against the native implementation.
  """
  half_supported = pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()
  return half_supported
206
207
def InstallStackTraceHandler():
  """Calls `pywrap_tensorflow.InstallStacktraceHandler()`; returns None."""
  install_handler = pywrap_tensorflow.InstallStacktraceHandler
  install_handler()
210
211
def NHWCToNCHW(input_tensor):
  """Converts the input from the NHWC format to NCHW.

  The trailing (channel) axis is moved to position 1. Only rank 4 and rank 5
  inputs are supported; other ranks raise KeyError.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # rank -> permutation that moves the last axis to position 1
  permutation_by_rank = {
      4: [0, 3, 1, 2],
      5: [0, 4, 1, 2, 3],
  }
  if not isinstance(input_tensor, ops.Tensor):
    shape = input_tensor
    return [shape[axis] for axis in permutation_by_rank[len(shape)]]
  permutation = permutation_by_rank[input_tensor.shape.ndims]
  return array_ops.transpose(input_tensor, permutation)
229
230
def NHWCToNCHW_VECT_C(input_shape_or_tensor):
  """Transforms the input from the NHWC layout to NCHW_VECT_C layout.

  Note: Does not include quantization or type conversion steps, which should
  be applied afterwards.

  The last (channel) dimension C is split into (C // 4, 4), and the axes are
  then permuted so that C // 4 comes right after the batch axis and the
  length-4 vector axis is last. A shape given as a list or tuple is never
  modified; a new list is returned.

  Args:
    input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NCHW_VECT_C

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not evenly
        divisible by 4.
  """
  permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  # Always work on a fresh list. The in-place edits below would otherwise
  # mutate a caller-supplied shape list, and they fail outright on a tuple.
  temp_shape = list(
      input_shape_or_tensor.shape.as_list()
      if is_tensor else input_shape_or_tensor)
  if temp_shape[-1] % 4 != 0:
    raise ValueError(
        "Last dimension of input must be evenly divisible by 4 to convert to "
        "NCHW_VECT_C.")
  temp_shape[-1] //= 4
  temp_shape.append(4)
  permutation = permutations[len(temp_shape)]
  if is_tensor:
    t = array_ops.reshape(input_shape_or_tensor, temp_shape)
    return array_ops.transpose(t, permutation)
  else:
    return [temp_shape[a] for a in permutation]
264
265
def NCHW_VECT_CToNHWC(input_shape_or_tensor):
  """Transforms the input from the NCHW_VECT_C layout to NHWC layout.

  Note: Does not include de-quantization or type conversion steps, which should
  be applied beforehand.

  The length-4 vector axis is folded back into the channel dimension, which
  becomes the last axis of the result.

  Args:
    input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NHWC

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not 4.
  """
  rank_to_permutation = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  if is_tensor:
    input_shape = input_shape_or_tensor.shape.as_list()
  else:
    input_shape = input_shape_or_tensor
  if input_shape[-1] != 4:
    raise ValueError("Last dimension of NCHW_VECT_C must be 4.")
  permutation = rank_to_permutation[len(input_shape)]
  # Leave out the trailing vector axis, then multiply its size back into the
  # channel dimension, which the permutation has placed last.
  nhwc_shape = [input_shape[axis] for axis in permutation[:-1]]
  nhwc_shape[-1] *= input_shape[-1]
  if not is_tensor:
    return nhwc_shape
  transposed = array_ops.transpose(input_shape_or_tensor, permutation)
  return array_ops.reshape(transposed, nhwc_shape)
296
297
def NCHWToNHWC(input_tensor):
  """Converts the input from the NCHW format to NHWC.

  The axis at position 1 (channels) is moved to the end. Only rank 4 and
  rank 5 inputs are supported; other ranks raise KeyError.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # rank -> permutation that moves axis 1 to the last position
  permutation_by_rank = {
      4: [0, 2, 3, 1],
      5: [0, 2, 3, 4, 1],
  }
  if not isinstance(input_tensor, ops.Tensor):
    shape = input_tensor
    return [shape[axis] for axis in permutation_by_rank[len(shape)]]
  permutation = permutation_by_rank[input_tensor.shape.ndims]
  return array_ops.transpose(input_tensor, permutation)
315
316
317# TODO(skyewm): remove this eventually
318# pylint: disable=protected-access
def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs):
  """Calls `fn(*args, **kwargs)` with `ops._USE_C_API` set to `use_c_api`.

  The previous flag value is restored afterwards, even if `fn` raises. The
  default graph is reset once after the flag is set and again after it is
  restored. The return value of `fn` is discarded.
  """
  saved_flag = ops._USE_C_API
  ops._USE_C_API = use_c_api
  try:
    # Reset (rather than enter a new default Graph context) so the default
    # graph picks up the new setting while staying robust to tests that call
    # reset_default_graph() themselves, which requires that the current
    # default graph isn't nested.
    ops.reset_default_graph()
    fn(*args, **kwargs)
  finally:
    ops._USE_C_API = saved_flag
    # Keep the default graph consistent with the restored flag in case the
    # next test doesn't call reset_default_graph().
    ops.reset_default_graph()
334# pylint: disable=protected-access
335
336
def c_api_and_cuda_enabled():
  """Returns a truthy value iff the C API flag is set and CUDA is enabled.

  If `ops._USE_C_API` is falsy it is returned unchanged and
  `IsGoogleCudaEnabled()` is not called.
  """
  if not ops._USE_C_API:
    return ops._USE_C_API
  return IsGoogleCudaEnabled()
339
340
def skip_if(condition):
  """Skips the decorated function if condition is or evaluates to True.

  The condition is evaluated on every call of the decorated function, so a
  callable condition can depend on state only known at test time. A skipped
  call simply returns None; it is not reported as skipped to the test runner.

  Args:
    condition: Either an expression that can be used in "if not condition"
               statement, or a callable whose result should be a boolean.
  Returns:
    A decorator. The function it produces returns the wrapped function's
    result when not skipped, and None when skipped.
  """

  def real_skip_if(fn):

    # functools.wraps keeps fn's name/docstring for reporting and debugging.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      skip = condition() if callable(condition) else condition
      if not skip:
        return fn(*args, **kwargs)
      return None

    return wrapper

  return real_skip_if
364
365
366# TODO(skyewm): remove this eventually
def disable_c_api(fn):
  """Decorator for disabling the C API on a test.

  Note this disables the C API after running the test class's setup/teardown
  methods.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function. It carries `fn`'s name and docstring and returns
    None.
  """

  # functools.wraps preserves fn's metadata (name, docstring, and attributes
  # such as unittest skip markers).
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    _use_c_api_wrapper(fn, False, *args, **kwargs)

  return wrapper
384
385
386# TODO(skyewm): remove this eventually
def enable_c_api(fn):
  """Decorator for enabling the C API on a test.

  Note this enables the C API after running the test class's setup/teardown
  methods.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function. It carries `fn`'s name and docstring and returns
    None.
  """

  # functools.wraps preserves fn's metadata (name, docstring, and attributes
  # such as unittest skip markers).
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    _use_c_api_wrapper(fn, True, *args, **kwargs)

  return wrapper
404
405
406# This decorator is a hacky way to run all the test methods in a decorated
407# class with and without C API enabled.
408# TODO(iga): Remove this and its uses once we switch to using C API by default.
def with_c_api(cls):
  """Adds methods that call original methods but with C API enabled.

  For every callable attribute defined directly on `cls` whose name starts
  with "test", a sibling named `<name>WithCApi` is added that runs the
  original through `enable_c_api`. The class is modified in place.

  Note this enables the C API in new methods after running the test class's
  setup method. This can be a problem if some objects are created in it
  before the C API is enabled.

  Args:
    cls: class to decorate

  Returns:
    cls with new test methods added
  """
  # Snapshot the attributes first, because setattr below adds new entries.
  original_attrs = list(cls.__dict__.items())
  for attr_name, attr_value in original_attrs:
    if not callable(attr_value):
      continue
    if not attr_name.startswith("test"):
      continue
    setattr(cls, attr_name + "WithCApi", enable_c_api(attr_value))
  return cls
426
427
def assert_no_new_tensors(f):
  """Decorator for asserting that no new Tensors persist after a test.

  Mainly useful for checking that code using the Python C API has correctly
  manipulated reference counts.

  The decorated test does the following:
    * records the ids of all live `ops.Tensor` / `variables.Variable` objects;
    * runs the test in a fresh default graph that reuses the outer graph's
      `_graph_key`;
    * flushes the caches it knows about and runs the garbage collector;
    * raises AssertionError if any Tensor or Variable that was not present
      before is still alive.

  This catches Tensors that something still references (e.g. from missing
  Py_DECREFs) as well as uncollectable cycles (i.e. Python reference cycles
  where one of the objects has __del__ defined).

  Args:
    f: The test case to run.
  Returns:
    The decorated test case.
  """

  def decorator(self, **kwargs):
    """Finds existing Tensors, runs the test, checks for new Tensors."""

    def _is_tensor(obj):
      try:
        return isinstance(obj, (ops.Tensor, variables.Variable))
      except ReferenceError:
        # If the object no longer exists, we don't care about it.
        return False

    preexisting_ids = {id(obj) for obj in gc.get_objects() if _is_tensor(obj)}
    outer_graph_key = ops.get_default_graph()._graph_key
    with ops.Graph().as_default():
      # A new graph means collections are dropped when the test is done;
      # inheriting the graph key keeps optimizers behaving.
      ops.get_default_graph()._graph_key = outer_graph_key
      f(self, **kwargs)
    # Clear caches whose entries would otherwise look like leaked Tensors.
    backprop._zeros_cache.flush()
    context.get_default_context().scalar_cache().clear()
    gc.collect()
    leaked = [
        obj for obj in gc.get_objects()
        if _is_tensor(obj) and id(obj) not in preexisting_ids
    ]
    if leaked:
      raise AssertionError("%d Tensors not deallocated after test: %s" %
                           (len(leaked), str(leaked)))

  return decorator
480
481
def assert_no_garbage_created(f):
  """Test method decorator to assert that no garbage has been created.

  Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters
  cannot be un-set (i.e. will disable garbage collection for any other unit
  tests in the same file/shard).

  The previous gc debug flags and enabled state are restored even when the test
  raises or the garbage check fails.

  Args:
    f: The function to decorate.
  Returns:
    The decorated function. It returns `f`'s result.
  """

  @functools.wraps(f)
  def decorator(self, **kwargs):
    """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
    gc_was_enabled = gc.isenabled()
    gc.disable()
    previous_debug_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)
    try:
      gc.collect()
      previous_garbage = len(gc.garbage)
      result = f(self, **kwargs)
      gc.collect()
      # This will fail if any garbage has been created, typically because of a
      # reference cycle.
      self.assertEqual(previous_garbage, len(gc.garbage))
    finally:
      # Restore in `finally` so a failing test does not leave gc disabled and
      # DEBUG_SAVEALL set for every later test.
      # TODO(allenl): Figure out why this debug flag reset doesn't work. It
      # would be nice to be able to decorate arbitrary tests in a large test
      # suite and not hold on to every object in other tests.
      gc.set_debug(previous_debug_flags)
      if gc_was_enabled:
        gc.enable()
    return result

  return decorator
514
515
def run_in_graph_and_eager_modes(__unused__=None,
                                 graph=None,
                                 config=None,
                                 use_gpu=False,
                                 force_gpu=False,
                                 reset_test=True,
                                 assert_no_eager_garbage=False):
  """Runs the test in both graph and eager modes.

  Args:
    __unused__: Prevents silently skipping tests. If the decorator is applied
      without parentheses, the test function lands here and an
      AssertionError is raised.
    graph: Optional graph to use during the returned session.
    config: An optional config_pb2.ConfigProto to use to configure the
      session.
    use_gpu: If True, attempt to run as many ops as possible on GPU.
    force_gpu: If True, pin all ops to `/device:GPU:0`.
    reset_test: If True, tearDown and SetUp the test case again.
    assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage
      collector and asserts that no extra garbage has been created when running
      the test in eager mode. This will fail if there are reference cycles
      (e.g. a = []; a.append(a)). Off by default because some tests may create
      garbage for legitimate reasons (e.g. they define a class which inherits
      from `object`), and because DEBUG_SAVEALL is sticky in some Python
      interpreters (meaning that tests which rely on objects being collected
      elsewhere in the unit test file will not work). Additionally, checks that
      nothing still has a reference to Tensors that the test allocated.
  Returns:
    Returns a decorator that will run the decorated test function
        using both a graph and using eager execution.
  """

  assert not __unused__, "Add () after run_in_graph_and_eager_modes."

  def decorator(f):
    """Test method decorator."""

    @functools.wraps(f)
    def decorated(self, **kwargs):
      """Decorated the test method."""
      with context.graph_mode():
        with self.test_session(graph, config, use_gpu, force_gpu):
          f(self, **kwargs)

      if reset_test:
        # This decorator runs the wrapped test twice.
        # Reset the test environment between runs.
        self.tearDown()
        self.setUp()

      def run_eager_mode(self, **kwargs):
        if force_gpu:
          gpu_name = gpu_device_name()
          if not gpu_name:
            gpu_name = "/device:GPU:0"
          with context.device(gpu_name):
            # Forward kwargs here too; every other branch already does.
            f(self, **kwargs)
        elif use_gpu:
          # TODO(xpan): Support softplacement and gpu by default when available.
          f(self, **kwargs)
        else:
          with context.device("/device:CPU:0"):
            f(self, **kwargs)

      if assert_no_eager_garbage:
        run_eager_mode = assert_no_new_tensors(
            assert_no_garbage_created(run_eager_mode))

      with context.eager_mode():
        with ops.Graph().as_default():
          run_eager_mode(self, **kwargs)

    return decorated

  return decorator
589
590
@tf_export("test.is_gpu_available")
def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
  """Returns whether TensorFlow can access a GPU.

  Args:
    cuda_only: limit the search to CUDA gpus.
    min_cuda_compute_capability: a (major,minor) pair that indicates the minimum
      CUDA compute capability required, or None if no requirement. Any
      two-element sequence (tuple or list) is accepted. A GPU whose
      description has no parseable compute capability is treated as (0, 0).

  Returns:
    True iff a gpu device of the requested kind is available.
  """

  def compute_capability_from_device_desc(device_desc):
    # TODO(jingyue): The device description generator has to be in sync with
    # this file. Another option is to put compute capability in
    # DeviceAttributes, but I avoided that to keep DeviceAttributes
    # target-independent. Reconsider this option when we have more things like
    # this to keep in sync.
    # LINT.IfChange
    match = re.search(r"compute capability: (\d+)\.(\d+)", device_desc)
    # LINT.ThenChange(//tensorflow/core/\
    #                 common_runtime/gpu/gpu_device.cc)
    if not match:
      return 0, 0
    return int(match.group(1)), int(match.group(2))

  # Normalize to a tuple. Comparing the parsed (major, minor) tuple against a
  # list raises TypeError on Python 3 (and compares wrongly on Python 2).
  if min_cuda_compute_capability is not None:
    min_cuda_compute_capability = tuple(min_cuda_compute_capability)

  for local_device in device_lib.list_local_devices():
    if local_device.device_type == "GPU":
      if (min_cuda_compute_capability is None or
          compute_capability_from_device_desc(local_device.physical_device_desc)
          >= min_cuda_compute_capability):
        return True
    if local_device.device_type == "SYCL" and not cuda_only:
      return True
  return False
627
628
@contextlib.contextmanager
def device(use_gpu):
  """Uses gpu when requested and available.

  Pins ops created inside the block to "/device:GPU:0" when `use_gpu` is True
  and `is_gpu_available()` returns True, and to "/device:CPU:0" otherwise.
  `is_gpu_available()` is only called when `use_gpu` is True.
  """
  dev = "/device:GPU:0" if use_gpu and is_gpu_available() else "/device:CPU:0"
  with ops.device(dev):
    yield
638
639
640@tf_export("test.TestCase")
641class TensorFlowTestCase(googletest.TestCase):
642  """Base class for tests that need to test TensorFlow.
643  """
644
645  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
646    super(TensorFlowTestCase, self).__init__(methodName)
647    self._threads = []
648    self._tempdir = None
649    self._cached_session = None
650
651  def setUp(self):
652    self._ClearCachedSession()
653    random.seed(random_seed.DEFAULT_GRAPH_SEED)
654    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
655    # Note: The following line is necessary because some test methods may error
656    # out from within nested graph contexts (e.g., via assertRaises and
657    # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty
658    # under certain versions of Python. That would cause
659    # ops.reset_default_graph() to throw an exception if the stack were not
660    # cleared first.
661    ops._default_graph_stack.reset()  # pylint: disable=protected-access
662    ops.reset_default_graph()
663    random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
664
665  def tearDown(self):
666    for thread in self._threads:
667      thread.check_termination()
668
669    self._ClearCachedSession()
670
671  def _ClearCachedSession(self):
672    if self._cached_session is not None:
673      self._cached_session.close()
674      self._cached_session = None
675
676  def get_temp_dir(self):
677    """Returns a unique temporary directory for the test to use.
678
679    If you call this method multiple times during in a test, it will return the
680    same folder. However, across different runs the directories will be
681    different. This will ensure that across different runs tests will not be
682    able to pollute each others environment.
683    If you need multiple unique directories within a single test, you should
684    use tempfile.mkdtemp as follows:
685      tempfile.mkdtemp(dir=self.get_temp_dir()):
686
687    Returns:
688      string, the path to the unique temporary directory created for this test.
689    """
690    if not self._tempdir:
691      self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
692    return self._tempdir
693
694  def _AssertProtoEquals(self, a, b, msg=None):
695    """Asserts that a and b are the same proto.
696
697    Uses ProtoEq() first, as it returns correct results
698    for floating point attributes, and then use assertProtoEqual()
699    in case of failure as it provides good error messages.
700
701    Args:
702      a: a proto.
703      b: another proto.
704      msg: Optional message to report on failure.
705    """
706    if not compare.ProtoEq(a, b):
707      compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg)
708
709  def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None):
710    """Asserts that message is same as parsed expected_message_ascii.
711
712    Creates another prototype of message, reads the ascii message into it and
713    then compares them using self._AssertProtoEqual().
714
715    Args:
716      expected_message_maybe_ascii: proto message in original or ascii form.
717      message: the message to validate.
718      msg: Optional message to report on failure.
719    """
720    msg = msg if msg else ""
721    if isinstance(expected_message_maybe_ascii, type(message)):
722      expected_message = expected_message_maybe_ascii
723      self._AssertProtoEquals(expected_message, message)
724    elif isinstance(expected_message_maybe_ascii, str):
725      expected_message = type(message)()
726      text_format.Merge(
727          expected_message_maybe_ascii,
728          expected_message,
729          descriptor_pool=descriptor_pool.Default())
730      self._AssertProtoEquals(expected_message, message, msg=msg)
731    else:
732      assert False, ("Can't compare protos of type %s and %s. %s" %
733                     (type(expected_message_maybe_ascii), type(message), msg))
734
735  def assertProtoEqualsVersion(
736      self,
737      expected,
738      actual,
739      producer=versions.GRAPH_DEF_VERSION,
740      min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER,
741      msg=None):
742    expected = "versions { producer: %d min_consumer: %d };\n%s" % (
743        producer, min_consumer, expected)
744    self.assertProtoEquals(expected, actual, msg=msg)
745
746  def assertStartsWith(self, actual, expected_start, msg=None):
747    """Assert that actual.startswith(expected_start) is True.
748
749    Args:
750      actual: str
751      expected_start: str
752      msg: Optional message to report on failure.
753    """
754    if not actual.startswith(expected_start):
755      fail_msg = "%r does not start with %r" % (actual, expected_start)
756      fail_msg += " : %r" % (msg) if msg else ""
757      self.fail(fail_msg)
758
759  def _eval_tensor(self, tensor):
760    if tensor is None:
761      return None
762    elif isinstance(tensor, ops.EagerTensor):
763      return tensor.numpy()
764    elif isinstance(tensor, resource_variable_ops.ResourceVariable):
765      return tensor.read_value().numpy()
766    elif callable(tensor):
767      return self._eval_helper(tensor())
768    else:
769      raise ValueError("Unsupported type %s." % type(tensor))
770
771  def _eval_helper(self, tensors):
772    if tensors is None:
773      return None
774    return nest.map_structure(self._eval_tensor, tensors)
775
776  def evaluate(self, tensors):
777    """Evaluates tensors and returns numpy values.
778
779    Args:
780      tensors: A Tensor or a nested list/tuple of Tensors.
781
782    Returns:
783      tensors numpy values.
784    """
785    if context.in_eager_mode():
786      return self._eval_helper(tensors)
787    else:
788      sess = ops.get_default_session()
789      if sess is None:
790        with self.test_session() as sess:
791          return sess.run(tensors)
792      else:
793        return sess.run(tensors)
794
795  # pylint: disable=g-doc-return-or-yield
796  @contextlib.contextmanager
797  def test_session(self,
798                   graph=None,
799                   config=None,
800                   use_gpu=False,
801                   force_gpu=False):
802    """Returns a TensorFlow Session for use in executing tests.
803
804    This method should be used for all functional tests.
805
806    This method behaves different than session.Session: for performance reasons
807    `test_session` will by default (if `graph` is None) reuse the same session
808    across tests. This means you may want to either call the function
809    `reset_default_graph()` before tests, or if creating an explicit new graph,
810    pass it here (simply setting it with `as_default()` won't do it), which will
811    trigger the creation of a new session.
812
813    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
814    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
815    `use_gpu`
816    is True, TensorFlow tries to run as many ops on the GPU as possible. If both
817    `force_gpu and `use_gpu` are False, all ops are pinned to the CPU.
818
819    Example:
820    ```python
821    class MyOperatorTest(test_util.TensorFlowTestCase):
822      def testMyOperator(self):
823        with self.test_session(use_gpu=True):
824          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
825          result = MyOperator(valid_input).eval()
826          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
827          invalid_input = [-1.0, 2.0, 7.0]
828          with self.assertRaisesOpError("negative input not supported"):
829            MyOperator(invalid_input).eval()
830    ```
831
832    Args:
833      graph: Optional graph to use during the returned session.
834      config: An optional config_pb2.ConfigProto to use to configure the
835        session.
836      use_gpu: If True, attempt to run as many ops as possible on GPU.
837      force_gpu: If True, pin all ops to `/device:GPU:0`.
838
839    Returns:
840      A Session object that should be used as a context manager to surround
841      the graph building and execution code in a test case.
842    """
843    if self.id().endswith(".test_session"):
844      self.skipTest("Not a test.")
845
846    def prepare_config(config):
847      """Returns a config for sessions.
848
849      Args:
850        config: An optional config_pb2.ConfigProto to use to configure the
851          session.
852      Returns:
853        A config_pb2.ConfigProto object.
854      """
855      if config is None:
856        config = config_pb2.ConfigProto()
857        config.allow_soft_placement = not force_gpu
858        config.gpu_options.per_process_gpu_memory_fraction = 0.3
859      elif force_gpu and config.allow_soft_placement:
860        config = config_pb2.ConfigProto().CopyFrom(config)
861        config.allow_soft_placement = False
862      # Don't perform optimizations for tests so we don't inadvertently run
863      # gpu ops on cpu
864      config.graph_options.optimizer_options.opt_level = -1
865      config.graph_options.rewrite_options.constant_folding = (
866          rewriter_config_pb2.RewriterConfig.OFF)
867      config.graph_options.rewrite_options.arithmetic_optimization = (
868          rewriter_config_pb2.RewriterConfig.OFF)
869      return config
870
871    if graph is None:
872      if self._cached_session is None:
873        self._cached_session = session.Session(
874            graph=None, config=prepare_config(config))
875      sess = self._cached_session
876      with sess.graph.as_default(), sess.as_default():
877        if force_gpu:
878          # Use the name of an actual device if one is detected, or '/device:GPU:0'
879          # otherwise
880          gpu_name = gpu_device_name()
881          if not gpu_name:
882            gpu_name = "/device:GPU:0"
883          with sess.graph.device(gpu_name):
884            yield sess
885        elif use_gpu:
886          yield sess
887        else:
888          with sess.graph.device("/cpu:0"):
889            yield sess
890    else:
891      with session.Session(graph=graph, config=prepare_config(config)) as sess:
892        if force_gpu:
893          # Use the name of an actual device if one is detected, or '/device:GPU:0'
894          # otherwise
895          gpu_name = gpu_device_name()
896          if not gpu_name:
897            gpu_name = "/device:GPU:0"
898          with sess.graph.device(gpu_name):
899            yield sess
900        elif use_gpu:
901          yield sess
902        else:
903          with sess.graph.device("/cpu:0"):
904            yield sess
905
906  # pylint: enable=g-doc-return-or-yield
907
908  class _CheckedThread(object):
909    """A wrapper class for Thread that asserts successful completion.
910
911    This class should be created using the TensorFlowTestCase.checkedThread()
912    method.
913    """
914
915    def __init__(self, testcase, target, args=None, kwargs=None):
916      """Constructs a new instance of _CheckedThread.
917
918      Args:
919        testcase: The TensorFlowTestCase for which this thread is being created.
920        target: A callable object representing the code to be executed in the
921          thread.
922        args: A tuple of positional arguments that will be passed to target.
923        kwargs: A dictionary of keyword arguments that will be passed to target.
924      """
925      self._testcase = testcase
926      self._target = target
927      self._args = () if args is None else args
928      self._kwargs = {} if kwargs is None else kwargs
929      self._thread = threading.Thread(target=self._protected_run)
930      self._exception = None
931
932      self._is_thread_joined = False
933
934    def _protected_run(self):
935      """Target for the wrapper thread. Sets self._exception on failure."""
936      try:
937        self._target(*self._args, **self._kwargs)
938      except Exception as e:  # pylint: disable=broad-except
939        self._exception = e
940
941    def start(self):
942      """Starts the thread's activity.
943
944      This must be called at most once per _CheckedThread object. It arranges
945      for the object's target to be invoked in a separate thread of control.
946      """
947      self._thread.start()
948
949    def join(self):
950      """Blocks until the thread terminates.
951
952      Raises:
953        self._testcase.failureException: If the thread terminates with due to
954          an exception.
955      """
956      self._is_thread_joined = True
957      self._thread.join()
958      if self._exception is not None:
959        self._testcase.fail("Error in checkedThread: %s" % str(self._exception))
960
961    def is_alive(self):
962      """Returns whether the thread is alive.
963
964      This method returns True just before the run() method starts
965      until just after the run() method terminates.
966
967      Returns:
968        True if the thread is alive, otherwise False.
969      """
970      return self._thread.is_alive()
971
972    def check_termination(self):
973      """Returns whether the checked thread was properly used and did terminate.
974
975      Every checked thread should be "join"ed after starting, and before the
976      test tears down. If it is not joined, it is possible the thread will hang
977      and cause flaky failures in tests.
978
979      Raises:
980        self._testcase.failureException: If check_termination was called before
981        thread was joined.
982
983        RuntimeError: If the thread is not terminated. This means thread was not
984        joined with the main thread.
985      """
986      if self._is_thread_joined:
987        if self.is_alive():
988          raise RuntimeError(
989              "Thread was not joined with main thread, and is still running "
990              "when the test finished.")
991      else:
992        self._testcase.fail("A checked thread was not joined.")
993
994  def checkedThread(self, target, args=None, kwargs=None):
995    """Returns a Thread wrapper that asserts 'target' completes successfully.
996
997    This method should be used to create all threads in test cases, as
998    otherwise there is a risk that a thread will silently fail, and/or
999    assertions made in the thread will not be respected.
1000
1001    Args:
1002      target: A callable object to be executed in the thread.
1003      args: The argument tuple for the target invocation. Defaults to ().
1004      kwargs: A dictionary of keyword arguments for the target invocation.
1005        Defaults to {}.
1006
1007    Returns:
1008      A wrapper for threading.Thread that supports start() and join() methods.
1009    """
1010    ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs)
1011    self._threads.append(ret)
1012    return ret
1013
1014
1015# pylint: enable=invalid-name
1016
1017  def assertNear(self, f1, f2, err, msg=None):
1018    """Asserts that two floats are near each other.
1019
1020    Checks that |f1 - f2| < err and asserts a test failure
1021    if not.
1022
1023    Args:
1024      f1: A float value.
1025      f2: A float value.
1026      err: A float value.
1027      msg: An optional string message to append to the failure message.
1028    """
1029    # f1 == f2 is needed here as we might have: f1, f2 = inf, inf
1030    self.assertTrue(f1 == f2 or math.fabs(f1 - f2) <= err,
1031                    "%f != %f +/- %f%s" % (f1, f2, err, " (%s)" % msg
1032                                           if msg is not None else ""))
1033
1034  def assertArrayNear(self, farray1, farray2, err, msg=None):
1035    """Asserts that two float arrays are near each other.
1036
1037    Checks that for all elements of farray1 and farray2
1038    |f1 - f2| < err.  Asserts a test failure if not.
1039
1040    Args:
1041      farray1: a list of float values.
1042      farray2: a list of float values.
1043      err: a float value.
1044      msg: Optional message to report on failure.
1045    """
1046    self.assertEqual(len(farray1), len(farray2), msg=msg)
1047    for f1, f2 in zip(farray1, farray2):
1048      self.assertNear(float(f1), float(f2), err, msg=msg)
1049
1050  def _NDArrayNear(self, ndarray1, ndarray2, err):
1051    return np.linalg.norm(ndarray1 - ndarray2) < err
1052
1053  def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None):
1054    """Asserts that two numpy arrays have near values.
1055
1056    Args:
1057      ndarray1: a numpy ndarray.
1058      ndarray2: a numpy ndarray.
1059      err: a float. The maximum absolute difference allowed.
1060      msg: Optional message to report on failure.
1061    """
1062    self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg)
1063
1064  def _GetNdArray(self, a):
1065    if not isinstance(a, np.ndarray):
1066      a = np.array(a)
1067    return a
1068
1069  def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
1070    a = self._GetNdArray(a)
1071    b = self._GetNdArray(b)
1072    self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s." %
1073                     (a.shape, b.shape))
1074    if not np.allclose(a, b, rtol=rtol, atol=atol):
1075      # Prints more details than np.testing.assert_allclose.
1076      #
1077      # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
1078      # checks whether two arrays are element-wise equal within a
1079      # tolerance. The relative difference (rtol * abs(b)) and the
1080      # absolute difference atol are added together to compare against
1081      # the absolute difference between a and b.  Here, we want to
1082      # print out which elements violate such conditions.
1083      cond = np.logical_or(
1084          np.abs(a - b) > atol + rtol * np.abs(b),
1085          np.isnan(a) != np.isnan(b))
1086      if a.ndim:
1087        x = a[np.where(cond)]
1088        y = b[np.where(cond)]
1089        print("not close where = ", np.where(cond))
1090      else:
1091        # np.where is broken for scalars
1092        x, y = a, b
1093      print("not close lhs = ", x)
1094      print("not close rhs = ", y)
1095      print("not close dif = ", np.abs(x - y))
1096      print("not close tol = ", atol + rtol * np.abs(y))
1097      print("dtype = %s, shape = %s" % (a.dtype, a.shape))
1098      # TODO(xpan): There seems to be a bug:
1099      # tensorflow/compiler/tests:binary_ops_test pass with float32
1100      # nan even though the equal_nan is False by default internally.
1101      np.testing.assert_allclose(
1102          a, b, rtol=rtol, atol=atol, err_msg=msg, equal_nan=True)
1103
1104  def _assertAllCloseRecursive(self,
1105                               a,
1106                               b,
1107                               rtol=1e-6,
1108                               atol=1e-6,
1109                               path=None,
1110                               msg=None):
1111    path = path or []
1112    path_str = (("[" + "][".join([str(p) for p in path]) + "]") if path else "")
1113    msg = msg if msg else ""
1114
1115    # Check if a and/or b are namedtuples.
1116    if hasattr(a, "_asdict"):
1117      a = a._asdict()
1118    if hasattr(b, "_asdict"):
1119      b = b._asdict()
1120    a_is_dict = isinstance(a, dict)
1121    if a_is_dict != isinstance(b, dict):
1122      raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" %
1123                       (path_str, path_str, msg))
1124    if a_is_dict:
1125      self.assertItemsEqual(
1126          a.keys(),
1127          b.keys(),
1128          msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" %
1129          (path_str, a.keys(), path_str, b.keys(), msg))
1130      for k in a:
1131        path.append(k)
1132        self._assertAllCloseRecursive(
1133            a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg)
1134        del path[-1]
1135    elif isinstance(a, (list, tuple)):
1136      # Try to directly compare a, b as ndarrays; if not work, then traverse
1137      # through the sequence, which is more expensive.
1138      try:
1139        a_as_ndarray = np.array(a)
1140        b_as_ndarray = np.array(b)
1141        self._assertArrayLikeAllClose(
1142            a_as_ndarray,
1143            b_as_ndarray,
1144            rtol=rtol,
1145            atol=atol,
1146            msg="Mismatched value: a%s is different from b%s. %s" %
1147            (path_str, path_str, msg))
1148      except (ValueError, TypeError) as e:
1149        if len(a) != len(b):
1150          raise ValueError(
1151              "Mismatched length: a%s has %d items, but b%s has %d items. %s" %
1152              (path_str, len(a), path_str, len(b), msg))
1153        for idx, (a_ele, b_ele) in enumerate(zip(a, b)):
1154          path.append(str(idx))
1155          self._assertAllCloseRecursive(
1156              a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg)
1157          del path[-1]
1158    # a and b are ndarray like objects
1159    else:
1160      try:
1161        self._assertArrayLikeAllClose(
1162            a,
1163            b,
1164            rtol=rtol,
1165            atol=atol,
1166            msg="Mismatched value: a%s is different from b%s." % (path_str,
1167                                                                  path_str))
1168      except TypeError as e:
1169        msg = "Error: a%s has %s, but b%s has %s" % (
1170            path_str, type(a), path_str, type(b))
1171        e.args = ((e.args[0] + ' : ' + msg,) + e.args[1:])
1172        raise
1173
1174  def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
1175    """Asserts that two structures of numpy arrays, have near values.
1176
1177    `a` and `b` can be arbitrarily nested structures. A layer of a nested
1178    structure can be a `dict`, `namedtuple`, `tuple` or `list`.
1179
1180    Args:
1181      a: The expected numpy `ndarray`, or anything that can be converted into a
1182          numpy `ndarray`, or any arbitrarily nested of structure of these.
1183      b: The actual numpy `ndarray`, or anything that can be converted into a
1184          numpy `ndarray`, or any arbitrarily nested of structure of these.
1185      rtol: relative tolerance.
1186      atol: absolute tolerance.
1187      msg: Optional message to report on failure.
1188
1189    Raises:
1190      ValueError: if only one of `a[p]` and `b[p]` is a dict or
1191          `a[p]` and `b[p]` have different length, where `[p]` denotes a path
1192          to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and
1193          `[p] = [1]['d']`, then `a[p] = (6, 7)`.
1194    """
1195    self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg)
1196
1197  def assertAllCloseAccordingToType(self,
1198                                    a,
1199                                    b,
1200                                    rtol=1e-6,
1201                                    atol=1e-6,
1202                                    float_rtol=1e-6,
1203                                    float_atol=1e-6,
1204                                    half_rtol=1e-3,
1205                                    half_atol=1e-3,
1206                                    bfloat16_rtol=1e-2,
1207                                    bfloat16_atol=1e-2,
1208                                    msg=None):
1209    """Like assertAllClose, but also suitable for comparing fp16 arrays.
1210
1211    In particular, the tolerance is reduced to 1e-3 if at least
1212    one of the arguments is of type float16.
1213
1214    Args:
1215      a: the expected numpy ndarray or anything can be converted to one.
1216      b: the actual numpy ndarray or anything can be converted to one.
1217      rtol: relative tolerance.
1218      atol: absolute tolerance.
1219      float_rtol: relative tolerance for float32.
1220      float_atol: absolute tolerance for float32.
1221      half_rtol: relative tolerance for float16.
1222      half_atol: absolute tolerance for float16.
1223      bfloat16_rtol: relative tolerance for bfloat16.
1224      bfloat16_atol: absolute tolerance for bfloat16.
1225      msg: Optional message to report on failure.
1226    """
1227    a = self._GetNdArray(a)
1228    b = self._GetNdArray(b)
1229    # types with lower tol are put later to overwrite previous ones.
1230    if (a.dtype == np.float32 or b.dtype == np.float32 or
1231        a.dtype == np.complex64 or b.dtype == np.complex64):
1232      rtol = max(rtol, float_rtol)
1233      atol = max(atol, float_atol)
1234    if a.dtype == np.float16 or b.dtype == np.float16:
1235      rtol = max(rtol, half_rtol)
1236      atol = max(atol, half_atol)
1237    if (a.dtype == dtypes.bfloat16.as_numpy_dtype or
1238        b.dtype == dtypes.bfloat16.as_numpy_dtype):
1239      rtol = max(rtol, bfloat16_rtol)
1240      atol = max(atol, bfloat16_atol)
1241
1242    self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
1243
1244  def assertAllEqual(self, a, b, msg=None):
1245    """Asserts that two numpy arrays have the same values.
1246
1247    Args:
1248      a: the expected numpy ndarray or anything can be converted to one.
1249      b: the actual numpy ndarray or anything can be converted to one.
1250      msg: Optional message to report on failure.
1251    """
1252    msg = msg if msg else ""
1253    a = self._GetNdArray(a)
1254    b = self._GetNdArray(b)
1255    self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s."
1256                     " %s" % (a.shape, b.shape, msg))
1257    same = (a == b)
1258
1259    if a.dtype == np.float32 or a.dtype == np.float64:
1260      same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
1261    if not np.all(same):
1262      # Prints more details than np.testing.assert_array_equal.
1263      diff = np.logical_not(same)
1264      if a.ndim:
1265        x = a[np.where(diff)]
1266        y = b[np.where(diff)]
1267        print("not equal where = ", np.where(diff))
1268      else:
1269        # np.where is broken for scalars
1270        x, y = a, b
1271      print("not equal lhs = ", x)
1272      print("not equal rhs = ", y)
1273      np.testing.assert_array_equal(a, b, err_msg=msg)
1274
1275  # pylint: disable=g-doc-return-or-yield
1276  @contextlib.contextmanager
1277  def assertRaisesWithPredicateMatch(self, exception_type,
1278                                     expected_err_re_or_predicate):
1279    """Returns a context manager to enclose code expected to raise an exception.
1280
1281    If the exception is an OpError, the op stack is also included in the message
1282    predicate search.
1283
1284    Args:
1285      exception_type: The expected type of exception that should be raised.
1286      expected_err_re_or_predicate: If this is callable, it should be a function
1287        of one argument that inspects the passed-in exception and
1288        returns True (success) or False (please fail the test). Otherwise, the
1289        error message is expected to match this regular expression partially.
1290
1291    Returns:
1292      A context manager to surround code that is expected to raise an
1293      exception.
1294    """
1295    if callable(expected_err_re_or_predicate):
1296      predicate = expected_err_re_or_predicate
1297    else:
1298
1299      def predicate(e):
1300        err_str = e.message if isinstance(e, errors.OpError) else str(e)
1301        op = e.op if isinstance(e, errors.OpError) else None
1302        while op is not None:
1303          err_str += "\nCaused by: " + op.name
1304          op = op._original_op  # pylint: disable=protected-access
1305        logging.info("Searching within error strings: '%s' within '%s'",
1306                     expected_err_re_or_predicate, err_str)
1307        return re.search(expected_err_re_or_predicate, err_str)
1308
1309    try:
1310      yield
1311      self.fail(exception_type.__name__ + " not raised")
1312    except Exception as e:  # pylint: disable=broad-except
1313      if not isinstance(e, exception_type) or not predicate(e):
1314        raise AssertionError("Exception of type %s: %s" % (str(type(e)),
1315                                                           str(e)))
1316
1317  # pylint: enable=g-doc-return-or-yield
1318
1319  def assertRaisesOpError(self, expected_err_re_or_predicate):
1320    return self.assertRaisesWithPredicateMatch(errors.OpError,
1321                                               expected_err_re_or_predicate)
1322
1323  def assertShapeEqual(self, np_array, tf_tensor, msg=None):
1324    """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape.
1325
1326    Args:
1327      np_array: A Numpy ndarray or Numpy scalar.
1328      tf_tensor: A Tensor.
1329      msg: Optional message to report on failure.
1330
1331    Raises:
1332      TypeError: If the arguments have the wrong type.
1333    """
1334    if not isinstance(np_array, (np.ndarray, np.generic)):
1335      raise TypeError("np_array must be a Numpy ndarray or Numpy scalar")
1336    if not isinstance(tf_tensor, ops.Tensor):
1337      raise TypeError("tf_tensor must be a Tensor")
1338    self.assertAllEqual(
1339        np_array.shape, tf_tensor.get_shape().as_list(), msg=msg)
1340
1341  def assertDeviceEqual(self, device1, device2, msg=None):
1342    """Asserts that the two given devices are the same.
1343
1344    Args:
1345      device1: A string device name or TensorFlow `DeviceSpec` object.
1346      device2: A string device name or TensorFlow `DeviceSpec` object.
1347      msg: Optional message to report on failure.
1348    """
1349    device1 = pydev.canonical_name(device1)
1350    device2 = pydev.canonical_name(device2)
1351    self.assertEqual(device1, device2,
1352                     "Devices %s and %s are not equal. %s" %
1353                     (device1, device2, msg))
1354
1355  # Fix Python 3 compatibility issues
1356  if six.PY3:
1357    # pylint: disable=invalid-name
1358
1359    # Silence a deprecation warning
1360    assertRaisesRegexp = googletest.TestCase.assertRaisesRegex
1361
1362    # assertItemsEqual is assertCountEqual as of 3.2.
1363    assertItemsEqual = googletest.TestCase.assertCountEqual
1364
1365    # pylint: enable=invalid-name
1366
1367
@tf_export("test.create_local_cluster")
def create_local_cluster(num_workers,
                         num_ps,
                         protocol="grpc",
                         worker_config=None,
                         ps_config=None):
  """Create and start local servers and return the associated `Server` objects.

  Each server listens on `localhost` on a port chosen by portpicker. All
  servers share one ClusterSpec with a "worker" job and a "ps" job.

  Example:
  ```python
  workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2)

  worker_sessions = [tf.Session(w.target) for w in workers]

  with tf.device("/job:ps/task:0"):
    ...
  with tf.device("/job:ps/task:1"):
    ...
  with tf.device("/job:worker/task:0"):
    ...
  with tf.device("/job:worker/task:1"):
    ...

  worker_sessions[0].run(...)
  ```

  Args:
    num_workers: Number of worker servers to start.
    num_ps: Number of PS servers to start.
    protocol: Communication protocol.  Allowed values are documented in
      the documentation of `tf.train.Server`.
    worker_config: (optional) ConfigProto to initialize workers. Can be used
      to instantiate multiple devices etc.
    ps_config: (optional) ConfigProto to initialize PS servers.

  Returns:
    A tuple `(worker_servers, ps_servers)`.  `worker_servers` is a list
    of `num_workers` objects of type `tf.train.Server` (all running locally);
    and `ps_servers` is a list of `num_ps` objects of similar type.

  Raises:
    ImportError: if portpicker module was not found at load time
  """
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type

  def _pick_ports(count):
    """Returns `count` currently unused local ports."""
    return [portpicker.pick_unused_port() for _ in range(count)]

  # Worker ports are picked before PS ports, and workers are started first.
  worker_ports = _pick_ports(num_workers)
  ps_ports = _pick_ports(num_ps)
  cluster_spec = server_lib.ClusterSpec({
      "worker": ["localhost:%s" % port for port in worker_ports],
      "ps": ["localhost:%s" % port for port in ps_ports],
  })

  def _start_servers(job_name, count, config):
    """Starts one server per task index of `job_name`."""
    return [
        server_lib.Server(
            cluster_spec,
            job_name=job_name,
            protocol=protocol,
            task_index=task,
            config=config,
            start=True) for task in range(count)
    ]

  return (_start_servers("worker", num_workers, worker_config),
          _start_servers("ps", num_ps, ps_config))
1441
1442
def get_node_def_from_graph(node_name, graph_def):
  """Returns the `NodeDef` instance for given node name in the graph def.

  Only the top-level `graph_def.node` entries are searched; the first entry
  whose `name` equals `node_name` wins.

  Args:
    node_name: Name of the NodeDef to search for.
    graph_def: An instance of `GraphDef` proto.

  Returns:
    The matching `NodeDef` instance, or None if no node has that name.
  """
  matches = (candidate for candidate in graph_def.node
             if candidate.name == node_name)
  return next(matches, None)
1459
1460
def set_producer_version(graph, producer_version):
  """Sets graph.graph_def_versions.producer to `producer_version`.

  This is done by importing an otherwise empty GraphDef that carries
  `producer_version` into `graph`.

  Args:
    graph: The graph whose producer version should be changed.
    producer_version: The desired producer version.

  Raises:
    AssertionError: (unless assertions are disabled) if, after the import, the
      graph does not report `producer_version` as its producer version.
  """
  # The C API doesn't expose altering GraphDefVersions. We can indirectly set
  # it via import_graph_def though.
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.producer = producer_version
  with graph.as_default():
    importer.import_graph_def(graph_def)
  # The previous `assert x, y` form only checked that the producer was
  # non-zero (`y` was merely the assertion message); compare the values.
  # NOTE(review): importing may only be able to lower the producer version
  # (e.g. if the importer keeps the minimum of old and new) — confirm for
  # callers that try to raise it.
  actual_producer = graph.graph_def_versions.producer
  assert actual_producer == producer_version, (
      "Expected graph producer version %s, got %s" %
      (producer_version, actual_producer))
1470