1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Test utility.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import dtypes 24from tensorflow.python.ops import variables 25from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops 26from tensorflow.python.platform import test 27from tensorflow.python.util import nest 28 29 30class PForTestCase(test.TestCase): 31 """Base class for test cases.""" 32 33 def _run_targets(self, targets1, targets2=None, run_init=True): 34 targets1 = nest.flatten(targets1) 35 targets2 = ([] if targets2 is None else nest.flatten(targets2)) 36 assert len(targets1) == len(targets2) or not targets2 37 if run_init: 38 init = variables.global_variables_initializer() 39 self.evaluate(init) 40 return self.evaluate(targets1 + targets2) 41 42 def run_and_assert_equal(self, targets1, targets2): 43 outputs = self._run_targets(targets1, targets2) 44 outputs = nest.flatten(outputs) # flatten SparseTensorValues 45 n = len(outputs) // 2 46 for i in range(n): 47 if outputs[i + n].dtype != np.object: 48 self.assertAllClose(outputs[i + n], outputs[i], rtol=1e-4, atol=1e-5) 49 else: 50 self.assertAllEqual(outputs[i + n], outputs[i]) 51 52 def _test_loop_fn(self, loop_fn, iters, 53 loop_fn_dtypes=dtypes.float32, 54 parallel_iterations=None): 55 t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters, 56 parallel_iterations=parallel_iterations) 57 t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters, 58 parallel_iterations=parallel_iterations) 59 self.run_and_assert_equal(t1, t2) 60