1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.kernels.functional_ops.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.eager import context 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import sparse_tensor 27from tensorflow.python.framework import test_util 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import gradients_impl 30from tensorflow.python.ops import init_ops 31from tensorflow.python.ops import map_fn 32from tensorflow.python.ops import math_ops 33from tensorflow.python.ops import variable_scope 34from tensorflow.python.ops import variables 35import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import 36from tensorflow.python.platform import test 37 38 39# pylint: disable=invalid-name 40def simple_scoped_fn(a, x): 41 """Simple function: (a, x) -> 2(x+a), but with "2" as a variable in scope.""" 42 with variable_scope.variable_scope("body"): 43 # Dummy variable, just to check that scoping works as intended. 44 two = variable_scope.get_variable( 45 "two", [], 46 dtype=dtypes.int32, 47 initializer=init_ops.constant_initializer(2)) 48 return math_ops.multiply(math_ops.add(a, x), two) 49 50 51@test_util.with_control_flow_v2 52class MapFnTest(test.TestCase): 53 54 @test_util.run_in_graph_and_eager_modes 55 def testMap_Simple(self): 56 nums = [1, 2, 3, 4, 5, 6] 57 elems = constant_op.constant(nums, name="data") 58 r = map_fn.map_fn( 59 lambda x: math_ops.multiply(math_ops.add(x, 3), 2), elems) 60 self.assertAllEqual( 61 np.array([(x + 3) * 2 for x in nums]), self.evaluate(r)) 62 63 def testMapDtypeEager(self): 64 with context.eager_mode(): 65 dtype = map_fn.map_fn(lambda x: constant_op.constant(""), 66 constant_op.constant([]), 67 dtype=dtypes.string).dtype 68 self.assertEqual(dtype, dtypes.string) 69 70 def testMapSparseTensor(self): 71 with self.cached_session(): 72 with self.assertRaises(TypeError): 73 map_fn.map_fn( 74 lambda x: x, 75 sparse_tensor.SparseTensor( 76 indices=[[0, 0], [0, 1], [1, 0]], 77 values=constant_op.constant([0, 1, 2]), 78 dense_shape=[2, 2])) 79 80 @test_util.run_in_graph_and_eager_modes 81 def testMapOverScalarErrors(self): 82 with self.assertRaisesRegexp(ValueError, "not scalars"): 83 map_fn.map_fn(lambda x: x, [1, 2]) 84 with self.assertRaisesRegexp(ValueError, "not a scalar"): 85 map_fn.map_fn(lambda x: x, 1) 86 87 @test_util.run_deprecated_v1 88 def testMap_Scoped(self): 89 with self.cached_session() as sess: 90 91 def double_scoped(x): 92 """2x with a dummy 2 that is scoped.""" 93 with variable_scope.variable_scope("body"): 94 # Dummy variable, just to check that scoping works as intended. 95 two = variable_scope.get_variable( 96 "two", [], 97 dtype=dtypes.int32, 98 initializer=init_ops.constant_initializer(2)) 99 return math_ops.multiply(x, two) 100 101 with variable_scope.variable_scope("root") as varscope: 102 elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") 103 doubles = np.array([2 * x for x in [1, 2, 3, 4, 5, 6]]) 104 105 r = map_fn.map_fn(double_scoped, elems) 106 # Check that we have the one variable we asked for here. 107 self.assertEqual(len(variables.trainable_variables()), 1) 108 self.assertEqual(variables.trainable_variables()[0].name, 109 "root/body/two:0") 110 sess.run([variables.global_variables_initializer()]) 111 self.assertAllEqual(doubles, self.evaluate(r)) 112 113 # Now let's reuse our single variable. 114 varscope.reuse_variables() 115 r = map_fn.map_fn(double_scoped, elems) 116 self.assertEqual(len(variables.trainable_variables()), 1) 117 self.assertAllEqual(doubles, self.evaluate(r)) 118 119 @test_util.run_deprecated_v1 120 def testMap_Grad(self): 121 with self.cached_session(): 122 param = constant_op.constant(2.0) 123 elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems") 124 y = map_fn.map_fn( 125 lambda x: math_ops.multiply(math_ops.square(x), param), elems) 126 r = gradients_impl.gradients(y, param)[0] 127 self.assertAllEqual(91.0, self.evaluate(r)) 128 r = gradients_impl.gradients(y, elems)[0] 129 self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0], self.evaluate(r)) 130 131 @test_util.run_in_graph_and_eager_modes 132 def testMap_SimpleNotTensor(self): 133 nums = np.array([1, 2, 3, 4, 5, 6]) 134 r = map_fn.map_fn( 135 lambda x: math_ops.multiply(math_ops.add(x, 3), 2), nums) 136 self.assertAllEqual( 137 np.array([(x + 3) * 2 for x in nums]), self.evaluate(r)) 138 139 @test_util.run_in_graph_and_eager_modes 140 def testMap_SingleInputMultiOutput(self): 141 nums = np.array([1, 2, 3, 4, 5, 6]) 142 r = map_fn.map_fn( 143 lambda x: ((x + 3) * 2, -(x + 3) * 2), 144 nums, 145 dtype=(dtypes.int64, dtypes.int64)) 146 self.assertEqual(2, len(r)) 147 self.assertEqual((6,), r[0].get_shape()) 148 self.assertEqual((6,), r[1].get_shape()) 149 received = self.evaluate(r) 150 self.assertAllEqual((nums + 3) * 2, received[0]) 151 self.assertAllEqual(-(nums + 3) * 2, received[1]) 152 153 @test_util.run_in_graph_and_eager_modes 154 def testMap_MultiOutputMismatchedDtype(self): 155 nums = np.array([1, 2, 3, 4, 5, 6]) 156 with self.assertRaisesRegexp( 157 TypeError, r"two structures don't have the same nested structure"): 158 # lambda emits tuple, but dtype is a list 159 map_fn.map_fn( 160 lambda x: ((x + 3) * 2, -(x + 3) * 2), 161 nums, 162 dtype=[dtypes.int64, dtypes.int64]) 163 164 @test_util.run_in_graph_and_eager_modes 165 def testMap_MultiInputSingleOutput(self): 166 nums = np.array([1, 2, 3, 4, 5, 6]) 167 r = map_fn.map_fn( 168 lambda x: x[0] * x[1][0] + x[1][1], (nums, (nums, -nums)), 169 dtype=dtypes.int64) 170 self.assertEqual((6,), r.get_shape()) 171 received = self.evaluate(r) 172 self.assertAllEqual(nums * nums + (-nums), received) 173 174 @test_util.run_in_graph_and_eager_modes 175 def testMap_MultiInputSameStructureOutput(self): 176 nums = np.array([1, 2, 3, 4, 5, 6]) 177 r = map_fn.map_fn(lambda x: (x[1][0], (x[1][1], x[0])), 178 (nums, (2 * nums, -nums))) 179 r = [r[0], r[1][0], r[1][1]] 180 self.assertEqual((6,), r[0].get_shape()) 181 self.assertEqual((6,), r[1].get_shape()) 182 self.assertEqual((6,), r[2].get_shape()) 183 received = self.evaluate(r) 184 self.assertAllEqual(2 * nums, received[0]) 185 self.assertAllEqual(-nums, received[1]) 186 self.assertAllEqual(nums, received[2]) 187 188 @test_util.run_in_graph_and_eager_modes 189 def testMapShape(self): 190 x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) 191 y = map_fn.map_fn(lambda e: e, x) 192 self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) 193 194 @test_util.run_deprecated_v1 195 def testMapUnknownShape(self): 196 x = array_ops.placeholder(dtypes.float32) 197 y = map_fn.map_fn(lambda e: e, x) 198 self.assertIs(None, y.get_shape().dims) 199 200 # TODO(b/124383826): this test fails in eager: the iterable is of length 0 so 201 # so the body of the while loop never executes 202 @test_util.run_v1_only("b/120545219") 203 def testMapEmptyScalar(self): 204 map_return = map_fn.map_fn(lambda x: 1, 205 constant_op.constant([], dtype=dtypes.int32)) 206 self.assertAllEqual([0], map_return.get_shape().dims) 207 self.assertAllEqual([0], self.evaluate(map_return).shape) 208 209 # TODO(b/124383826): this test fails in eager: the iterable is of length 0 so 210 # so the body of the while loop never executes 211 @test_util.run_v1_only("b/120545219") 212 def testMapEmptyTensor(self): 213 with self.cached_session(): 214 map_return = map_fn.map_fn(lambda x: array_ops.zeros([3, 2]), 215 constant_op.constant([])) 216 self.assertAllEqual([0, 3, 2], map_return.get_shape().dims) 217 self.assertAllEqual([0, 3, 2], self.evaluate(map_return).shape) 218 219 220if __name__ == "__main__": 221 test.main() 222 223# pylint: enable=invalid-name 224