# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLA tests for pfor."""
# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.compiler.tf2xla.python import xla as xla_ops
from tensorflow.python.compiler.xla import jit
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
from tensorflow.python.ops.parallel_for.test_util import PForTestCase
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class PForTest(PForTestCase):

  def __init__(self, method_name="runTest"):
    super(PForTest, self).__init__(method_name)
    context.context().enable_xla_devices()

  def test_xla_einsum(self):
    num_loop = 10
    x_series = random_ops.random_uniform([num_loop, 9, 9])
    y_series = random_ops.random_uniform([num_loop, 9, 1])

    def loop_fn(i):
      x = array_ops.gather(x_series, 0)  # invariant.
      y = array_ops.gather(y_series, 0)  # invariant.
      x_i = array_ops.gather(x_series, i)
      y_i = array_ops.gather(y_series, i)
      z1 = xla_ops.einsum(x_i, y, "ab,bc->ac")
      z2 = xla_ops.einsum(x, y_i, "ab,bc->ac")
      z3 = xla_ops.einsum(x, y, "ab,bc->ac")
      z4 = xla_ops.einsum(x_i, y_i, "ab,bc->ac")
      z5 = xla_ops.einsum(y_i, x_i, "cd,ce->de")  # Includes transpose.
      outputs = [z1, z2, z3, z4, z5]
      return outputs

    self._test_loop_fn(loop_fn, num_loop)

  def test_xla(self):

    def compute(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    def vectorized_compute(x):
      return pfor_control_flow_ops.vectorized_map(compute, x)

    result = xla.compile(
        vectorized_compute, inputs=[array_ops.ones((10, 5, 3))])
    self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))

  def test_function_jit_compile(self):

    def compute(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    @def_function.function(jit_compile=True)
    def vectorized_compute(x):
      return pfor_control_flow_ops.vectorized_map(compute, x)

    result = vectorized_compute(array_ops.ones((10, 5, 3)))
    self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))

  def test_xla_while_loop(self):

    def compute(x):
      return math_ops.reduce_mean(x, axis=0, keepdims=True)

    def vectorized_compute(x, i):
      inp = array_ops.gather(x, i)
      output = pfor_control_flow_ops.vectorized_map(compute, inp)
      output.set_shape([5, 1])
      return output

    def while_compute(x):
      return control_flow_ops.while_loop_v2(
          lambda i, _: i < 10,
          lambda i, y: (i + 1, y + vectorized_compute(x, i)),
          (0, array_ops.zeros([5, 1])))[1]

    result = xla.compile(while_compute, inputs=[array_ops.ones((10, 5, 3))])
    expected = array_ops.ones([5, 1]) * 10
    self.run_and_assert_equal(expected, result)

  def test_reduce_mean(self):
    x = random_ops.random_uniform([8, 3])

    @def_function.function(jit_compile=True)
    def f():

      def loop_fn(i, pfor_config):
        x_i = array_ops.gather(x, i)
        return x_i - pfor_config.reduce_mean(x_i)

      return pfor_control_flow_ops.pfor(loop_fn, 8)

    output = f()
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)


def _make_unstacked(cond, body, pfor_config):
  """Wraps `cond` and `body` so the while_loop condition is pfor invariant.

  The returned condition reduces the per-iteration "not done" flags across all
  pfor iterations, and the returned body masks its outputs with `where_v2` so
  that finished iterations keep their previous values.
  """

  def _cond(*args):
    return math_ops.reduce_any(pfor_config.reduce_concat(args[0]))

  def _body(*args):
    not_done = args[0]
    args = args[1:]
    not_done = math_ops.logical_and(not_done, cond(*args))
    outputs = body(*args)
    return (not_done,) + tuple(
        array_ops.where_v2(not_done, x, y) for x, y in zip(outputs, args))

  return _cond, _body


@test_util.run_all_in_graph_and_eager_modes
class WhileV2Test(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()
    super(WhileV2Test, self).setUp()

  def tearDown(self):
    if not self._enabled:
      control_flow_v2_toggles.disable_control_flow_v2()
    super(WhileV2Test, self).tearDown()

  def _test_loop_fn(self, loop_fn, iters, force_xla=False):
    """Runs pfor(loop_fn) with and without XLA jit and compares the results."""

    def f():
      return pfor_control_flow_ops.pfor(loop_fn, iters)

    @def_function.function
    def jit_f():
      with jit.experimental_jit_scope():
        return f()

    out = f()
    jit_out = jit_f()
    self.run_and_assert_equal(out, jit_out)
    # TODO(agarwal): The following may complain about uncompilable nodes. Hence
    # these are currently not enabled for all tests.
    if force_xla:
      out_exp_compile_f = def_function.function(jit_compile=True)(f)()
      self.run_and_assert_equal(out, out_exp_compile_f)
      out_xla_compile_f = xla.compile(f, inputs=[])
      self.run_and_assert_equal(out, out_xla_compile_f)

  def test_stateless_while(self):
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      return control_flow_ops.while_loop(
          lambda j, _: j < lengths_i,
          lambda j, t: (j + 1, t + array_ops.gather(x_i, j)),
          [0, 0.])

    self._test_loop_fn(loop_fn, 3)

  def test_while_with_variable(self):
    v = resource_variable_ops.ResourceVariable(5.)

    def loop_fn(_):
      _, output = control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + v),
          [0, 0.])
      return output

    self._test_loop_fn(loop_fn, 3)

  def test_while_unstacked_condition(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3, force_xla=True)

  def test_while_force_unstacked_condition(self):
    # The while_loop in this setup is similar to the one in
    # test_stateless_while, whose condition is loop variant. However, here we
    # wrap the cond and body of the loop in a way that makes the while_loop
    # condition pfor loop invariant. This allows xla compilation to work since
    # the vectorized code no longer needs to perform dynamic partitioning of
    # the inputs.
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      def _cond(j, _):
        return j < lengths_i

      def _body(j, t):
        return (j + 1, t + array_ops.gather(x_i, j))

      cond, body = _make_unstacked(_cond, _body, pfor_config)
      return control_flow_ops.while_loop(
          cond,
          body,
          [True, 0, 0.])

    self._test_loop_fn(loop_fn, 3, force_xla=True)


if __name__ == "__main__":
  test.main()