1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.kernels.functional_ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.eager import context
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import sparse_tensor
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import gradients_impl
30from tensorflow.python.ops import init_ops
31from tensorflow.python.ops import map_fn
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops import variable_scope
34from tensorflow.python.ops import variables
35import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
36from tensorflow.python.platform import test
37
38
39# pylint: disable=invalid-name
def simple_scoped_fn(a, x):
  """Maps (a, x) -> 2 * (x + a), keeping the constant 2 as a scoped variable."""
  with variable_scope.variable_scope("body"):
    # The "two" variable is a dummy; it exists only so tests can verify that
    # variable scoping inside the mapped function works as intended.
    two = variable_scope.get_variable(
        "two", [],
        dtype=dtypes.int32,
        initializer=init_ops.constant_initializer(2))
    total = math_ops.add(a, x)
    return math_ops.multiply(total, two)
49
50
@test_util.with_control_flow_v2
class MapFnTest(test.TestCase):
  """Tests for map_fn: basic mapping, dtypes, structures, scoping, gradients."""

  @test_util.run_in_graph_and_eager_modes
  def testMap_Simple(self):
    values = [1, 2, 3, 4, 5, 6]
    elems = constant_op.constant(values, name="data")

    def transform(x):
      return math_ops.multiply(math_ops.add(x, 3), 2)

    result = map_fn.map_fn(transform, elems)
    expected = np.array([(v + 3) * 2 for v in values])
    self.assertAllEqual(expected, self.evaluate(result))

  def testMapDtypeEager(self):
    # In eager mode the declared output dtype should be reflected on the
    # returned tensor.
    with context.eager_mode():
      mapped = map_fn.map_fn(
          lambda x: constant_op.constant(""),
          constant_op.constant([]),
          dtype=dtypes.string)
      self.assertEqual(mapped.dtype, dtypes.string)

  def testMapSparseTensor(self):
    # map_fn does not accept SparseTensor elems; it should raise TypeError.
    with self.cached_session():
      sparse_input = sparse_tensor.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0]],
          values=constant_op.constant([0, 1, 2]),
          dense_shape=[2, 2])
      with self.assertRaises(TypeError):
        map_fn.map_fn(lambda x: x, sparse_input)

  @test_util.run_in_graph_and_eager_modes
  def testMapOverScalarErrors(self):
    # Scalar elems (or lists of scalars) are rejected with a ValueError.
    with self.assertRaisesRegexp(ValueError, "not scalars"):
      map_fn.map_fn(lambda x: x, [1, 2])
    with self.assertRaisesRegexp(ValueError, "not a scalar"):
      map_fn.map_fn(lambda x: x, 1)

  @test_util.run_deprecated_v1
  def testMap_Scoped(self):
    with self.cached_session() as sess:

      def double_scoped(x):
        """2x, with the dummy 2 held in a scoped variable."""
        with variable_scope.variable_scope("body"):
          # Dummy variable, present only to check that scoping works.
          two = variable_scope.get_variable(
              "two", [],
              dtype=dtypes.int32,
              initializer=init_ops.constant_initializer(2))
          return math_ops.multiply(x, two)

      with variable_scope.variable_scope("root") as varscope:
        values = [1, 2, 3, 4, 5, 6]
        elems = constant_op.constant(values, name="data")
        expected = np.array([2 * v for v in values])

        result = map_fn.map_fn(double_scoped, elems)
        # Exactly one trainable variable must exist, under the full scope name.
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertEqual(variables.trainable_variables()[0].name,
                         "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
        self.assertAllEqual(expected, self.evaluate(result))

        # Mapping again with reuse enabled must not create a second variable.
        varscope.reuse_variables()
        result = map_fn.map_fn(double_scoped, elems)
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertAllEqual(expected, self.evaluate(result))

  @test_util.run_deprecated_v1
  def testMap_Grad(self):
    with self.cached_session():
      param = constant_op.constant(2.0)
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
      y = map_fn.map_fn(
          lambda x: math_ops.multiply(math_ops.square(x), param), elems)
      # d/d(param) of sum(param * x^2) is sum(x^2) = 91.
      grad_param = gradients_impl.gradients(y, param)[0]
      self.assertAllEqual(91.0, self.evaluate(grad_param))
      # d/d(x_i) of param * x_i^2 is 2 * param * x_i.
      grad_elems = gradients_impl.gradients(y, elems)[0]
      self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0],
                          self.evaluate(grad_elems))

  @test_util.run_in_graph_and_eager_modes
  def testMap_SimpleNotTensor(self):
    # A numpy array (not a Tensor) is also a valid elems argument.
    values = np.array([1, 2, 3, 4, 5, 6])

    def transform(x):
      return math_ops.multiply(math_ops.add(x, 3), 2)

    result = map_fn.map_fn(transform, values)
    self.assertAllEqual((values + 3) * 2, self.evaluate(result))

  @test_util.run_in_graph_and_eager_modes
  def testMap_SingleInputMultiOutput(self):
    nums = np.array([1, 2, 3, 4, 5, 6])

    def fork(x):
      doubled = (x + 3) * 2
      return doubled, -doubled

    r = map_fn.map_fn(fork, nums, dtype=(dtypes.int64, dtypes.int64))
    self.assertEqual(2, len(r))
    self.assertEqual((6,), r[0].get_shape())
    self.assertEqual((6,), r[1].get_shape())
    outputs = self.evaluate(r)
    self.assertAllEqual((nums + 3) * 2, outputs[0])
    self.assertAllEqual(-(nums + 3) * 2, outputs[1])

  @test_util.run_in_graph_and_eager_modes
  def testMap_MultiOutputMismatchedDtype(self):
    nums = np.array([1, 2, 3, 4, 5, 6])
    with self.assertRaisesRegexp(
        TypeError, r"two structures don't have the same nested structure"):
      # The callable yields a tuple while dtype declares a list.
      map_fn.map_fn(
          lambda x: ((x + 3) * 2, -(x + 3) * 2),
          nums,
          dtype=[dtypes.int64, dtypes.int64])

  @test_util.run_in_graph_and_eager_modes
  def testMap_MultiInputSingleOutput(self):
    nums = np.array([1, 2, 3, 4, 5, 6])

    def combine(x):
      first, (second, third) = x
      return first * second + third

    result = map_fn.map_fn(
        combine, (nums, (nums, -nums)), dtype=dtypes.int64)
    self.assertEqual((6,), result.get_shape())
    self.assertAllEqual(nums * nums + (-nums), self.evaluate(result))

  @test_util.run_in_graph_and_eager_modes
  def testMap_MultiInputSameStructureOutput(self):
    nums = np.array([1, 2, 3, 4, 5, 6])

    def rotate(x):
      outer, (inner_a, inner_b) = x
      return inner_a, (inner_b, outer)

    mapped = map_fn.map_fn(rotate, (nums, (2 * nums, -nums)))
    flat = [mapped[0], mapped[1][0], mapped[1][1]]
    for tensor in flat:
      self.assertEqual((6,), tensor.get_shape())
    received = self.evaluate(flat)
    self.assertAllEqual(2 * nums, received[0])
    self.assertAllEqual(-nums, received[1])
    self.assertAllEqual(nums, received[2])

  @test_util.run_in_graph_and_eager_modes
  def testMapShape(self):
    x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
    identity_mapped = map_fn.map_fn(lambda row: row, x)
    # Static shape inference should agree with the runtime shape.
    self.assertAllEqual(identity_mapped.get_shape(),
                        self.evaluate(identity_mapped).shape)

  @test_util.run_deprecated_v1
  def testMapUnknownShape(self):
    x = array_ops.placeholder(dtypes.float32)
    y = map_fn.map_fn(lambda e: e, x)
    # With an unknown-shape input the output shape must be unknown too.
    self.assertIs(None, y.get_shape().dims)

  # TODO(b/124383826): this test fails in eager: the iterable is of length 0,
  # so the body of the while loop never executes.
  @test_util.run_v1_only("b/120545219")
  def testMapEmptyScalar(self):
    mapped = map_fn.map_fn(lambda x: 1,
                           constant_op.constant([], dtype=dtypes.int32))
    self.assertAllEqual([0], mapped.get_shape().dims)
    self.assertAllEqual([0], self.evaluate(mapped).shape)

  # TODO(b/124383826): this test fails in eager: the iterable is of length 0,
  # so the body of the while loop never executes.
  @test_util.run_v1_only("b/120545219")
  def testMapEmptyTensor(self):
    with self.cached_session():
      mapped = map_fn.map_fn(lambda x: array_ops.zeros([3, 2]),
                             constant_op.constant([]))
      self.assertAllEqual([0, 3, 2], mapped.get_shape().dims)
      self.assertAllEqual([0, 3, 2], self.evaluate(mapped).shape)
218
219
if __name__ == "__main__":
  # Run all test cases in this module via the TensorFlow test runner.
  test.main()
222
223# pylint: enable=invalid-name
224