1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Test utility.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.ops import variables 24from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops 25from tensorflow.python.platform import test 26from tensorflow.python.util import nest 27 28 29class PForTestCase(test.TestCase): 30 """Base class for test cases.""" 31 32 def _run_targets(self, targets1, targets2=None, run_init=True): 33 targets1 = nest.flatten(targets1) 34 targets2 = ([] if targets2 is None else nest.flatten(targets2)) 35 assert len(targets1) == len(targets2) or not targets2 36 if run_init: 37 init = variables.global_variables_initializer() 38 self.evaluate(init) 39 return self.evaluate(targets1 + targets2) 40 41 # TODO(agarwal): Allow tests to pass down tolerances. 42 def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5): 43 outputs = self._run_targets(targets1, targets2) 44 outputs = nest.flatten(outputs) # flatten SparseTensorValues 45 n = len(outputs) // 2 46 for i in range(n): 47 if outputs[i + n].dtype != np.object: 48 self.assertAllClose(outputs[i + n], outputs[i], rtol=rtol, atol=atol) 49 else: 50 self.assertAllEqual(outputs[i + n], outputs[i]) 51 52 def _test_loop_fn(self, 53 loop_fn, 54 iters, 55 parallel_iterations=None, 56 fallback_to_while_loop=False, 57 rtol=1e-4, 58 atol=1e-5): 59 t1 = pfor_control_flow_ops.pfor( 60 loop_fn, 61 iters=iters, 62 fallback_to_while_loop=fallback_to_while_loop, 63 parallel_iterations=parallel_iterations) 64 loop_fn_dtypes = nest.map_structure(lambda x: x.dtype, t1) 65 t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters, 66 parallel_iterations=parallel_iterations) 67 self.run_and_assert_equal(t1, t2, rtol=rtol, atol=atol) 68