# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for tf.cond in XLA.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.compiler.tests import xla_test from tensorflow.python.client import session from tensorflow.python.compiler.xla import xla from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test @test_util.with_control_flow_v2 class CondTest(xla_test.XLATestCase): def testCondAndTensorArrayInDefun(self): # TODO(b/132430685): Make test more useful. Also b/129396295, b/127846988 with self.session(), self.test_scope(): xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() @function.defun def f(): ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1) output = control_flow_ops.cond( constant_op.constant(True), lambda: ta.write(0, 5.), lambda: ta.write(0, 10.)) return output.stack() output_t = f() self.assertAllEqual([5.], self.evaluate(output_t)) xla_context.Exit() def testCondAndTensorArrayInDefun_constFolding(self): g = ops.Graph() with session.Session(graph=g), g.as_default(), self.test_scope(): xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() @function.defun def f(): ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1) output = control_flow_ops.cond( constant_op.constant(False), lambda: ta.write(0, 5.), lambda: ta.write(0, 10.)) return output.stack() output_t = f() self.assertAllEqual([10.], self.evaluate(output_t)) xla_context.Exit() def testCondAndTensorArray_xlaCompile(self): self.skipTest("b/127846988") # Fails with "Uninitialized arguments" in XlaIfOp::Compile with self.session(), self.test_scope(): xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() def f(): ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1) output = control_flow_ops.cond( constant_op.constant(True), lambda: ta.write(0, 5.), lambda: ta.write(0, 10.)) return output.stack() output_t, = xla.compile(f) self.assertAllEqual([5.], self.evaluate(output_t)) xla_context.Exit() def testCondConstPropagation(self): with self.session() as sess, self.test_scope(): xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() x = array_ops.placeholder(dtypes.float32) p = array_ops.placeholder(dtypes.int32) # TODO(b/129021699): Wrapping this in a tf.function does not work. def if_true(): # This emits a StridedSlice op which expects the index to be a # compile-time const. return x[p] def if_false(): return 5. output = control_flow_ops.cond( constant_op.constant(True), if_true, if_false) self.assertAllEqual(1., sess.run(output, feed_dict={ x: [0., 1., 2.], p: 1 })) xla_context.Exit() def testCondConstPropagation_xlaCompile(self): self.skipTest("b/132430685") with self.session(), self.test_scope(): xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() x = array_ops.placeholder_with_default([0., 1., 2.], shape=[3]) p = constant_op.constant(1) def f(): # TODO(b/129021699): Wrapping this in a tf.function does not work. def if_true(): # This emits a StridedSlice op which expects the index to be a # compile-time const. return x[p] def if_false(): return 5. return control_flow_ops.cond( constant_op.constant(True), if_true, if_false) output = xla.compile(f) self.assertAllEqual(1., self.evaluate(output)) xla_context.Exit() def testCondConstPropagation_errorMsg(self): self.skipTest("b/132430685") with self.session() as sess, self.test_scope(): xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() x = array_ops.placeholder(dtypes.float32) p = random_ops.random_uniform([], minval=1, maxval=3, dtype=dtypes.int32) # TODO(b/129021699): Wrapping this in a tf.function does not work. def if_true(): # This emits a StridedSlice op which expects the index to be a # compile-time const. return x[:p] def if_false(): return array_ops.fill([p], 5.) output = control_flow_ops.cond( constant_op.constant(True), if_true, if_false) with self.assertRaisesRegex(errors.InvalidArgumentError, "must be a compile-time constant"): sess.run( output, feed_dict={ x: [0., 1., 2.], }) xla_context.Exit() def testCondConstPropagation_errorMsg_xlaCompile(self): with self.session() as sess, self.test_scope(): xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() x = array_ops.placeholder(dtypes.float32) p = random_ops.random_uniform([], minval=1, maxval=3, dtype=dtypes.int32) condition = math_ops.cast( random_ops.random_uniform([], minval=0, maxval=2, dtype=dtypes.int32), dtypes.bool) def f(): # TODO(b/129021699): Wrapping this in a tf.function does not work. def if_true(): # This emits a StridedSlice op which expects the index to be a # compile-time const. return x[:p] def if_false(): return array_ops.fill([p], 5.) return control_flow_ops.cond(condition, if_true, if_false) output = xla.compile(f) with self.assertRaisesRegex(errors.InvalidArgumentError, "must be a compile-time constant"): sess.run( output, feed_dict={ x: [0., 1., 2.], }) xla_context.Exit() def testSwitchCaseAndTensorArrayInDefun(self): self.skipTest("b/127846988") with self.session(), self.test_scope(): xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() @function.defun def f(): ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1) output = control_flow_ops.switch_case( constant_op.constant(1), { 0: lambda: ta.write(0, 5.), 1: lambda: ta.write(0, 10.), 2: lambda: ta.write(0, 15.), }) return output.stack() output_t = f() self.assertAllEqual([10.], self.evaluate(output_t)) xla_context.Exit() def testSwitchCaseAndTensorArray_xlaCompile(self): self.skipTest("b/127846988") with self.session(), self.test_scope(): xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() def f(): ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=1) output = control_flow_ops.switch_case( constant_op.constant(1), { 0: lambda: ta.write(0, 5.), 1: lambda: ta.write(0, 10.), 2: lambda: ta.write(0, 15.), }) return output.stack() output_t, = xla.compile(f) self.assertAllEqual([10.], self.evaluate(output_t)) xla_context.Exit() def testSwitchCaseConstPropagation(self): self.skipTest("b/127846988") with self.session() as sess, self.test_scope(): xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() x = array_ops.placeholder(dtypes.float32) p = array_ops.placeholder(dtypes.int32) def branch0(): return 5. def branch1(): return 15. # TODO(b/129021699): Wrapping this in a tf.function does not work. def branch2(): # This emits a StridedSlice op which expects the index to be a # compile-time const. return x[p] output = control_flow_ops.switch_case( constant_op.constant(2), { 0: branch0, 1: branch1, 2: branch2, }) self.assertAllEqual(7., sess.run(output, feed_dict={ x: [0., 1., 7.], p: 2, })) xla_context.Exit() def testCondNoInputs(self): """Verifies against `Failed precondition: Expected one input shape`.""" with self.session(), self.test_scope(): xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() for pred in True, False: cond_out = control_flow_ops.cond( array_ops.placeholder_with_default(pred, []), lambda: constant_op.constant(2.), lambda: constant_op.constant(1.)) self.assertEqual(int(pred) + 1., self.evaluate(cond_out)) xla_context.Exit() if __name__ == '__main__': test.main()