# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.kernels.functional_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import function as eager_function
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.util import compat


# pylint: disable=invalid-name
def simple_scoped_fn(a, x):
  """Simple function: (a, x) -> 2(x+a), but with "2" as a variable in scope."""
  with variable_scope.variable_scope("body"):
    # Dummy variable, just to check that scoping works as intended.
    two = variable_scope.get_variable(
        "two", [],
        dtype=dtypes.int32,
        initializer=init_ops.constant_initializer(2))
    return math_ops.multiply(math_ops.add(a, x), two)


@test_util.with_control_flow_v2
class FunctionalOpsTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testFoldl_Simple(self):
    elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
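    # With no initializer, foldl seeds the accumulator with elems[0]; applying
    # f(a, x) = 2 * (a + x) over the remaining elements gives 6, 18, 44, 98,
    # 208. With initializer=10 the fold runs over all six elements:
    # 22, 48, 102, 212, 434, 880.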

    r = functional_ops.foldl(
        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
        elems)
    self.assertAllEqual(208, self.evaluate(r))

    r = functional_ops.foldl(
        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
        elems,
        initializer=10)
    self.assertAllEqual(880, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testFoldl_SingleInputMultiOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array([1, -1.0])
    r = functional_ops.foldl(lambda a, x: a + x, elems, initializer)
    r_value = self.evaluate(r)

    self.assertAllEqual(22, r_value[0])
    self.assertAllEqual(20, r_value[1])

  @test_util.run_in_graph_and_eager_modes
  def testFoldl_MultiInputSingleOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array(1.0)
    r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems),
                             initializer)
    self.assertAllEqual(1, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testFoldl_MultiInputDifferentDimsSingleOutput(self):
    elems = np.array([[1.0, 1.0, 1.0], [2.0, 3.0, 4.0]])
    other_elems = np.array([-1.0, 1.0])
    initializer = np.array([0.0, 0.0, 0.0])
    r = functional_ops.foldl(lambda a, x: a + x[0] * x[1],
                             (elems, other_elems), initializer)
    self.assertAllEqual([1.0, 2.0, 3.0], self.evaluate(r))

  @test_util.run_deprecated_v1
  def testFoldl_Scoped(self):
    with self.cached_session() as sess:
      with variable_scope.variable_scope("root") as varscope:
        elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

        r = functional_ops.foldl(simple_scoped_fn, elems)
        # Check that we have the one variable we asked for here.
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertEqual(variables.trainable_variables()[0].name,
                         "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
        self.assertAllEqual(208, self.evaluate(r))

        # Now let's reuse our single variable.
        varscope.reuse_variables()
        r = functional_ops.foldl(simple_scoped_fn, elems, initializer=10)
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertAllEqual(880, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testFoldr_Simple(self):
    elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
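    # foldr seeds the accumulator with elems[-1] and folds from the right:
    # 2 * (6 + 5) = 22, then 52, 110, 224, 450. With initializer=10 it folds
    # over all six elements: 32, 74, 156, 318, 640, 1282.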

    r = functional_ops.foldr(
        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
        elems)
    self.assertAllEqual(450, self.evaluate(r))

    r = functional_ops.foldr(
        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
        elems,
        initializer=10)
    self.assertAllEqual(1282, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testFoldr_SingleInputMultiOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array([1, -1.0])
    r = functional_ops.foldr(lambda a, x: a + x, elems, initializer)
    r_value = self.evaluate(r)

    self.assertAllEqual(22, r_value[0])
    self.assertAllEqual(20, r_value[1])

  @test_util.run_in_graph_and_eager_modes
  def testFoldr_MultiInputSingleOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array(1.0)
    r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems),
                             initializer)
    self.assertAllEqual(1, self.evaluate(r))

  @test_util.run_deprecated_v1
  def testFoldr_Scoped(self):
    with self.cached_session() as sess:
      with variable_scope.variable_scope("root") as varscope:
        elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

        r = functional_ops.foldr(simple_scoped_fn, elems)
        # Check that we have the one variable we asked for here.
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertEqual(variables.trainable_variables()[0].name,
                         "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
        self.assertAllEqual(450, self.evaluate(r))

        # Now let's reuse our single variable.
        varscope.reuse_variables()
        r = functional_ops.foldr(simple_scoped_fn, elems, initializer=10)
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertAllEqual(1282, self.evaluate(r))

  # pylint: disable=unnecessary-lambda
  @test_util.run_deprecated_v1
  def testFold_Grad(self):
    with self.cached_session():
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
      v = constant_op.constant(2.0, name="v")
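      # Folding multiplication over elems gives r = v * 6! = 720 * v in either
      # direction, so dr/dv = 720 for both foldl and foldr.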
      r = functional_ops.foldl(
          lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
      r = gradients_impl.gradients(r, v)[0]
      self.assertAllEqual(720.0, self.evaluate(r))

      r = functional_ops.foldr(
          lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
      r = gradients_impl.gradients(r, v)[0]
      self.assertAllEqual(720.0, self.evaluate(r))
  # pylint: enable=unnecessary-lambda

  @test_util.run_in_graph_and_eager_modes
  def testScan_Simple(self):
    elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
    v = constant_op.constant(2.0, name="v")

    # pylint: disable=unnecessary-lambda
    r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems)
    self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r))

    r = functional_ops.scan(
        lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
    self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))
    # pylint: enable=unnecessary-lambda

  @test_util.run_in_graph_and_eager_modes
  def testScan_Reverse(self):
    elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
    v = constant_op.constant(2.0, name="v")

    # pylint: disable=unnecessary-lambda
    r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems,
                            reverse=True)
    self.assertAllEqual([720., 720., 360., 120., 30., 6.], self.evaluate(r))
    r = functional_ops.scan(
        lambda a, x: math_ops.multiply(a, x), elems, initializer=v,
        reverse=True)
    self.assertAllEqual([1440., 1440., 720., 240., 60., 12.],
                        self.evaluate(r))
    # pylint: enable=unnecessary-lambda

  @test_util.run_in_graph_and_eager_modes
  def testScan_SingleInputMultiOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = (np.array(1.0), np.array(-1.0))
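    # The first accumulator is the running product of elems; the second is the
    # same running product with an alternating sign, because each step
    # multiplies it by -x.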
    r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems,
                            initializer)
    r_value = self.evaluate(r)

    self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
    self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])

  @test_util.run_in_graph_and_eager_modes
  def testScan_MultiInputSingleOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array(1.0)
    # Multiply a * 1 each time
    r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]),
                            (elems + 1, -elems), initializer)
    self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testScan_MultiInputSameTypeOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]),
                            (elems, -elems))
    r_value = self.evaluate(r)
    self.assertAllEqual(np.cumsum(elems), r_value[0])
    self.assertAllEqual(np.cumsum(-elems), r_value[1])

  @test_util.run_in_graph_and_eager_modes
  def testScan_MultiOutputMismatchedInitializer(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array(1.0)
    # The fn output structure (a, -a) doesn't match the scalar initializer.
    with self.assertRaisesRegexp(
        ValueError, "two structures don't have the same nested structure"):
      functional_ops.scan(lambda a, x: (a, -a), elems, initializer)

  @test_util.run_deprecated_v1
  def testScan_Scoped(self):
    with self.cached_session() as sess:
      with variable_scope.variable_scope("root") as varscope:
        elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

        r = functional_ops.scan(simple_scoped_fn, elems)
        # Check that we have the one variable we asked for here.
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertEqual(variables.trainable_variables()[0].name,
                         "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
        results = np.array([1, 6, 18, 44, 98, 208])
        self.assertAllEqual(results, self.evaluate(r))

        # Now let's reuse our single variable.
        varscope.reuse_variables()
        r = functional_ops.scan(simple_scoped_fn, elems, initializer=2)
        self.assertEqual(len(variables.trainable_variables()), 1)
        results = np.array([6, 16, 38, 84, 178, 368])
        self.assertAllEqual(results, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testScanFoldl_Nested(self):
    elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data")
    inner_elems = constant_op.constant([0.5, 0.5], name="data")

    def r_inner(a, x):
      return functional_ops.foldl(
          lambda b, y: b * y * x, inner_elems, initializer=a)

    r = functional_ops.scan(r_inner, elems)

    # t == 0 (returns 1)
    # t == 1, a == 1, x == 2 (returns 1)
    #   t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1
    #   t_1 == 1, b == 1,      y == 0.5, returns b * y * x = 1
    # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25)
    #   t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5
    #   t_1 == 1, b == 1.5,    y == 0.5, returns b * y * x = 1.5*1.5
    # t == 3, a == 2.25, x == 4 (returns 9)
    #   t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5
    #   t_1 == 1, b == 4.5,       y == 0.5, returns b * y * x = 9
    self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r))

  @test_util.run_deprecated_v1
  def testScan_Control(self):
    with self.cached_session() as sess:
      s = array_ops.placeholder(dtypes.float32, shape=[None])
      b = array_ops.placeholder(dtypes.bool)

      with ops.control_dependencies([b]):
        c = functional_ops.scan(lambda a, x: x * a, s)
      self.assertAllClose(
          np.array([1.0, 3.0, 9.0]), sess.run(c, {s: [1, 3, 3],
                                                  b: True}))

  @test_util.run_deprecated_v1
  def testScan_Grad(self):
    with self.cached_session():
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
      v = constant_op.constant(2.0, name="v")
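      # The scan produces [v, 2v, 6v, 24v, 120v, 720v]; gradients() sums the
      # per-element gradients, so dr/dv = 1 + 2 + 6 + 24 + 120 + 720 = 873.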

      # pylint: disable=unnecessary-lambda
      r = functional_ops.scan(
          lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
      # pylint: enable=unnecessary-lambda
      r = gradients_impl.gradients(r, v)[0]
      self.assertAllEqual(873.0, self.evaluate(r))

  @test_util.run_deprecated_v1
  def testScanGradientWithPartStopGradient(self):
    a = variables.Variable(0.0, name="a")
    b = variables.Variable(0.0, name="b")
    elems = array_ops.zeros(5)
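    # Smoke test: taking gradients through a scan whose second output is
    # wrapped in stop_gradient should still build and evaluate cleanly.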
    l0, l1 = functional_ops.scan(
        lambda elem_, input_: (a, b), elems, initializer=(0., 0.))
    loss = l0 + array_ops.stop_gradient(l1)
    grad = gradients_impl.gradients(ys=[loss], xs=[a, b])
    with self.test_session(use_gpu=True) as sess:
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(grad)

  @test_util.run_in_graph_and_eager_modes
  def testFoldShape(self):
    x = constant_op.constant([[1, 2, 3], [4, 5, 6]])

    def fn(_, current_input):
      return current_input

    initializer = constant_op.constant([0, 0, 0])
    y = functional_ops.foldl(fn, x, initializer=initializer)
    self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)

  @test_util.run_in_graph_and_eager_modes
  def testScanShape(self):
    x = constant_op.constant([[1, 2, 3], [4, 5, 6]])

    def fn(_, current_input):
      return current_input

    initializer = constant_op.constant([0, 0, 0])
    y = functional_ops.scan(fn, x, initializer=initializer)
    self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)

  # TODO(akshayka): this test fails in eager: the iterable is of length 0, so
  # the body of the while loop never executes.
  @test_util.run_deprecated_v1
  def testScanEmptyTensor(self):
    with self.cached_session():
      x = functional_ops.scan(
          lambda x, _: x, math_ops.range(0), initializer=array_ops.ones([2, 4]))
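      # Scanning over zero elements yields a result whose leading dimension is
      # 0, followed by the initializer's shape [2, 4].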
      self.assertAllEqual([0, 2, 4], x.get_shape())
      self.assertAllEqual(x.get_shape(), self.evaluate(x).shape)

  @test_util.run_deprecated_v1
  def testScanUnknownShape(self):
    x = array_ops.placeholder(dtypes.float32)
    initializer = array_ops.placeholder(dtypes.float32)

    def fn(_, current_input):
      return current_input

    y = functional_ops.scan(fn, x, initializer=initializer)
    self.assertIs(None, y.get_shape().dims)

  @test_util.run_deprecated_v1
  def testScanVaryingShape(self):
    with self.cached_session() as sess:
      x = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 2])
      x_t = array_ops.transpose(x)
      # scan over dimension 0 (with shape None)
      result = functional_ops.scan(lambda a, x: a + x, x)
      # scanned over transposed dimension 0 (with shape 2)
      result_t = functional_ops.scan(lambda a, x: a + x, x_t, infer_shape=False)
      # ensure gradients can be calculated
      result_grad = gradients_impl.gradients(result, [x])[0]
      result_t_grad = gradients_impl.gradients(result_t, [x_t])[0]

      # smoke test to ensure they all evaluate
      sess.run([result, result_t, result_grad, result_t_grad],
               feed_dict={x: [[1.0, 2.0]]})

  @test_util.run_deprecated_v1
  def testRemoteFunction(self):
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 2
    worker, _ = test_util.create_local_cluster(
        1, 1, worker_config=worker_config)

    @function.Defun(dtypes.int32, dtypes.int32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:ps/task:0"):
      a = variables.Variable(2, dtype=dtypes.int32)
      b = variables.Variable(3, dtype=dtypes.int32)

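    # remote_call runs _remote_fn on the target device (here the worker's
    # second CPU) and returns its outputs to the calling device.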
    with ops.device("/job:worker/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target="/job:worker/replica:0/task:0/cpu:1")

    with session.Session(worker[0].target) as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, [6])

  @test_util.run_deprecated_v1
  def testRemoteFunctionDirectSession(self):
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 2

    @function.Defun(dtypes.int32, dtypes.int32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      a = variables.Variable(2, dtype=dtypes.int32)
      b = variables.Variable(3, dtype=dtypes.int32)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target="/job:localhost/replica:0/task:0/cpu:1")

    with self.test_session(config=worker_config) as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, [6])

  @test_util.run_deprecated_v1
  def testRemoteFunctionSameDeviceDirectSession(self):

    @function.Defun(dtypes.int32, dtypes.int32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/cpu:0"):
      a = variables.Variable(2, dtype=dtypes.int32)
      b = variables.Variable(3, dtype=dtypes.int32)

    with ops.device("/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0")

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, [6])

  @test_util.run_deprecated_v1
  def testRemoteFunctionCPUGPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    @function.Defun(dtypes.float32, dtypes.float32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      a = variables.Variable(2, dtype=dtypes.float32)
      b = variables.Variable(3, dtype=dtypes.float32)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.float32],
          f=_remote_fn,
          target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, 9.0)

  @test_util.run_deprecated_v1
  def testRemoteFunctionGPUCPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    @function.Defun(dtypes.float32, dtypes.float32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
      a = variables.Variable(2, dtype=dtypes.float32)
      b = variables.Variable(3, dtype=dtypes.float32)

    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.float32],
          f=_remote_fn,
          target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, 9.0)

  @test_util.run_deprecated_v1
  def testRemoteFunctionGPUCPUStrings(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    @function.Defun(dtypes.string)
    def _remote_fn(inp):
      return array_ops.identity(inp)

    a = array_ops.constant("a")

    with ops.device("/gpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a], Tout=[dtypes.string], f=_remote_fn, target="/cpu:0")

    with self.cached_session() as sess:
      ret = self.evaluate(remote_op)
      self.assertAllEqual(ret, [b"a"])

  @test_util.run_deprecated_v1
  def testRemoteFunctionCrossProcess(self):
    workers, _ = test_util.create_local_cluster(2, 1)

    @function.Defun(dtypes.float32, dtypes.float32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:ps/task:0"):
      a = variables.Variable(2, dtype=dtypes.float32)
      b = variables.Variable(3, dtype=dtypes.float32)

    with ops.device("/job:worker/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.float32],
          f=_remote_fn,
          target="/job:worker/replica:0/task:1/cpu:0")[0] + 3.0

    with session.Session(workers[0].target) as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, 9)

  @test_util.run_deprecated_v1
  def testIf(self):

    @function.Defun(dtypes.float32)
    def Twice(x):
      return x * 2

    @function.Defun(dtypes.float32)
    def Thrice(x):
      return x * 3 + 1

    with self.test_session(use_gpu=False) as sess:

      x = array_ops.placeholder(dtypes.float32)
      ret = functional_ops.If(math_ops.greater(x, 0), [x], Twice, Thrice)[0]

      self.assertAllEqual(sess.run(ret, feed_dict={x: 9.}), 18.)
      self.assertAllEqual(sess.run(ret, feed_dict={x: -8.}), -23.)
      self.assertAllEqual(sess.run(ret, feed_dict={x: 0.}), 1.)

  def testWhile(self):

    for use_gpu in (True, False):
      with ops.Graph().as_default() as g:

        @function.Defun(*[dtypes.float32] * 2)
        def Cond(n, unused_x):
          return n > 0

        @function.Defun(*[dtypes.float32] * 2)
        def Body(n, x):
          return n - 1, x + n

        def Run(sess, n):
          return sess.run(functional_ops.While([n, 0.], Cond, Body))[1]

        with self.session(graph=g, use_gpu=use_gpu) as sess:
          self.assertAllEqual(Run(sess, 20.), 210.)
          self.assertAllEqual(Run(sess, 100.), 5050.)

  # Like above, but using int32 in order to ensure that int32 tensors don't get
  # copied to the GPU during the application of the while.
  def testWhileInt32(self):
    with ops.Graph().as_default() as g:

      @function.Defun(*[dtypes.int32] * 2)
      def Cond(n, unused_x):
        return n > 0

      @function.Defun(*[dtypes.int32] * 2)
      def Body(n, x):
        return n - 1, x + n

      def Run(sess, n):
        return sess.run(functional_ops.While([n, 0], Cond, Body))[1]

      with self.session(graph=g, use_gpu=True) as sess:
        self.assertAllEqual(Run(sess, 20), 210)
        self.assertAllEqual(Run(sess, 100), 5050)

  @test_util.run_deprecated_v1
  def testWhileLowering(self):

    def Run(n, fetch_by_name):
      for use_gpu in (True, False):
        with ops.Graph().as_default() as g:

          @function.Defun(*[dtypes.float32] * 2)
          def Cond(n, unused_x):
            return n > 0

          @function.Defun(*[dtypes.float32] * 2)
          def Body(n, x):
            return n - 1, x + n

          # outputs: [0, n*(n+1)/2]
          outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while")

          # `outputs` is the list of output tensors of the While op. We
          # arbitrarily choose the 0th tensor to get the While op and set the
          # lowering attribute on it.
          outputs[0].op._set_attr("_lower_using_switch_merge",
                                  attr_value_pb2.AttrValue(b=True))
          if not fetch_by_name:
            fetch = outputs[1]
          else:
            fetch = "my_while:1"
        with self.session(graph=g, use_gpu=use_gpu) as sess:
          return self.evaluate(fetch)

    self.assertAllEqual(Run(20., False), 210.)
    self.assertAllEqual(Run(20., True), 210.)
    self.assertAllEqual(Run(100., False), 5050.)
    self.assertAllEqual(Run(100., True), 5050.)

  @test_util.run_v1_only("b/120545219")
  @test_util.disable_xla("b/123337890")  # Different error message
  def testWhileError(self):
    for use_gpu in (True, False):
      with ops.Graph().as_default() as g:

        @function.Defun(*[dtypes.float32] * 2)
        def Cond(n, unused_x):
          return n > 0

        @function.Defun(*[dtypes.float32] * 2)
        def CondReturnsTooManyArgs(n, x):
          return n > 0, x

        @function.Defun(*[dtypes.float32] * 2)
        def Body(n, x):
          return n - 1, x + n

        @function.Defun(*[dtypes.float32] * 2)
        def BodyReturnsTooManyArgs(n, x):
          return n - 1, x + n, x

        with self.session(graph=g, use_gpu=use_gpu):
          with self.assertRaisesRegexp(
              errors.InvalidArgumentError,
              "Expected a single scalar.*got 2 tensors."):
            functional_ops.While([5., 0.], CondReturnsTooManyArgs,
                                 Body)[0].eval()
          with self.assertRaisesRegexp(
              errors.InvalidArgumentError,
              "While loop body returned 3 arguments. Expected: 2"):
            functional_ops.While([5., 0.], Cond,
                                 BodyReturnsTooManyArgs)[0].eval()

  def testWhileInMultipleSubgraphs(self):

    for use_gpu in (True, False):
      with ops.Graph().as_default() as g:

        @function.Defun(*[dtypes.float32] * 2)
        def Cond(n, x):  # pylint: disable=unused-argument
          return n > 0

        @function.Defun(*[dtypes.float32] * 2)
        def Body(n, x):
          return n - 1, x + n

        with self.session(graph=g, use_gpu=use_gpu) as sess:
          n = array_ops.placeholder(dtypes.float32)
          _, result = functional_ops.While([n, 0.], Cond, Body)
          c = constant_op.constant(37.)

          self.assertAllEqual(210., sess.run(result, feed_dict={n: 20.}))
          self.assertAllEqual(5050., sess.run(result, feed_dict={n: 100.}))
          # Test that the result is the same when we run a different subgraph.
          self.assertAllEqual(5050.,
                              sess.run([result, c], feed_dict={n: 100.})[0])

  # pylint: disable=cell-var-from-loop
  def testWhileCapturedInputs(self):
    for use_gpu in (True, False):
      with ops.Graph().as_default() as g:
        v = variables.Variable(1.0)

        def TestCond(n, *args):
          del args
          return n < 10

        @function.Defun(*[dtypes.float32] * 2)
        def TestUnary(n, x):
          return math_ops.add(n, 1), x + n + v

        @function.Defun(*[dtypes.float32] * 3)
        def TestBinary(n, x, x2):
          return math_ops.add(n, 1), x + n + v, x2 + v

        with self.session(graph=g, use_gpu=use_gpu) as sess:
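          # The loop runs while n < 10, i.e. for n = 1..9. TestUnary therefore
          # accumulates x + n + v with v == 1: sum(1..9) + 9 = 54, and the loop
          # exits with n == 10. TestBinary additionally adds v to x2 on each of
          # the nine iterations, giving 9.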
          result_unary = functional_ops.While(
              [1.0, 0.],
              function.Defun(*[dtypes.float32] * 2)(TestCond), TestUnary)
          result_binary = functional_ops.While(
              [1.0, 0., 0.],
              function.Defun(*[dtypes.float32] * 3)(TestCond), TestBinary)
          self.evaluate(variables.global_variables_initializer())
          assert len(result_unary) == 2
          self.assertEqual([10.0, 54.0], self.evaluate(result_unary))
          assert len(result_binary) == 3
          self.assertEqual([10.0, 54.0, 9.0], self.evaluate(result_binary))

          def TestCondCapture(n, *args):
            del args
            return math_ops.cast(n, dtypes.float32) + v < 10

          with self.assertRaises(ValueError):
            _ = functional_ops.While(
                [1],
                function.Defun(dtypes.int32)(TestCondCapture),
                function.Defun(dtypes.int32, dtypes.float32)(TestUnary))

  # pylint: enable=cell-var-from-loop

  def _tfSum(self, use_gpu, rewrite_with_while):
    with ops.Graph().as_default() as g:
      with self.session(graph=g, use_gpu=use_gpu) as sess:

        @function.Defun(dtypes.int32, dtypes.float32)
        def Body(n, x):
          return x + math_ops.cast(n, dtypes.float32)

        xs = [
            # 1 + 2 + ... + 20
            functional_ops.For(
                1, 21, 1, [0.], Body, rewrite_with_while=rewrite_with_while)[0],
            # 100 + 99 + ... + 1
            functional_ops.For(
                100, 0, -1, [0.], Body, rewrite_with_while=rewrite_with_while)
            [0],
        ]
        xvals = self.evaluate(xs)
      self.assertAllEqual(210, xvals[0])
      self.assertAllEqual(5050, xvals[1])

  def testFor(self):
    for use_gpu in (True, False):
      self._tfSum(use_gpu, False)

  def testForWithWhile(self):
    for use_gpu in (True, False):
      self._tfSum(use_gpu, True)

  def testForWithWhileNaming(self):
    g = ops.Graph()
    with g.as_default():

      @function.Defun(dtypes.int32, dtypes.float32, func_name="TestBody")
      def TestBody(n, x):
        return x + math_ops.cast(n, dtypes.float32)

      _ = functional_ops.For(
          1, 21, 1, [0.], TestBody, rewrite_with_while=True)[0]

    names = []
    for func in g.as_graph_def().library.function:
      names.append(func.signature.name)
    self.assertTrue("TestBody" in names)
    self.assertTrue("TestBody_Cond" in names)
    self.assertTrue("TestBody_Body" in names)

  @test_util.run_deprecated_v1
  def testForCapturedInputs(self):
    v = variables.Variable(1.0)

    @function.Defun(dtypes.int32)
    def TestNullary(n):
      v + math_ops.cast(n, dtypes.float32)  # pylint: disable=expression-not-assigned

    @function.Defun(dtypes.int32, dtypes.float32)
    def TestUnary(n, x):
      return x + math_ops.cast(n, dtypes.float32) + v

    @function.Defun(dtypes.int32, dtypes.float32, dtypes.float32)
    def TestBinary(n, x, x2):
      return x + math_ops.cast(n, dtypes.float32) + v, x2 + v

    for rewrite_with_while in (True, False):
      use_gpu = not rewrite_with_while
      with self.test_session(use_gpu=use_gpu) as sess:
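        # For runs the body for i = 1..9 (the limit 10 is exclusive). TestUnary
        # accumulates x + i + v with v == 1, giving sum(1..9) + 9 = 54;
        # TestBinary also adds v to x2 on each of the nine iterations, giving 9.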
        result_nullary = functional_ops.For(
            1, 10, 1, [], TestNullary,
            rewrite_with_while=rewrite_with_while)
        result_unary = functional_ops.For(
            1, 10, 1, [0.], TestUnary,
            rewrite_with_while=rewrite_with_while)
        result_binary = functional_ops.For(
            1, 10, 1, [0., 0.], TestBinary,
            rewrite_with_while=rewrite_with_while)
        self.evaluate(variables.global_variables_initializer())
        assert not result_nullary
        # The nullary variant doesn't return anything so we can't easily run it.
        # As a total hack, fetch the operation by name and run it.
        sess.run(ops.get_default_graph().get_operation_by_name(
            "While" if rewrite_with_while else "For"))
        assert len(result_unary) == 1
        self.assertEqual([54.0], self.evaluate(result_unary))
        assert len(result_binary) == 2
        self.assertEqual([54.0, 9.0], self.evaluate(result_binary))

  def _tfMLP(self, xval, wsval, bsval, rewrite_with_while):
    # On GPU, don't rewrite using a while loop.
    use_gpu = not rewrite_with_while
    with self.test_session(use_gpu=use_gpu):

      @function.Defun(dtypes.int32, *[dtypes.float64] * 3)
      def MLP(i, a, ws, bs):
        a = math_ops.tanh(math_ops.matmul(a, ws[i, :]) + bs[i, :])
        return a, ws, bs

      ret = functional_ops.For(
          0,
          wsval.shape[0],
          1, [xval, wsval, bsval],
          MLP,
          rewrite_with_while=rewrite_with_while)[0]

      return self.evaluate(ret)

  def _npMLP(self, xval, wsval, bsval):
    for i in range(wsval.shape[0]):
      xval = np.tanh(np.dot(xval, wsval[i, :]) + bsval[i, :])
    return xval

  def _testForMLP(self, rewrite_with_while):
    # We construct a 5-layer multi-layer perceptron network here. Each layer
    # has the same number of hidden units (3), and the activation function is
    # tanh(). We feed the input (xval) with batch size 2.
    xval = np.random.normal(size=(2, 3))
    wsval = np.random.normal(size=(5, 3, 3))
    bsval = np.random.normal(size=(5, 3))
    np_ans = self._npMLP(xval, wsval, bsval)
    tf_for_ans = self._tfMLP(xval, wsval, bsval, rewrite_with_while)
    self.assertAllClose(np_ans, tf_for_ans)

  @test_util.run_deprecated_v1
  def testForMLP(self):
    self._testForMLP(False)

  @test_util.run_deprecated_v1
  def testForMLPWhile(self):
    self._testForMLP(True)

  @test_util.run_v1_only("b/120545219")
  def testForError(self):

    @function.Defun(dtypes.int32, dtypes.float32)
    def Foo(i, v):
      return math_ops.cast(i, dtypes.float32) + v

    @function.Defun(dtypes.int32, dtypes.float32)
    def ReturnsTooManyArgs(unused_i, v):
      return v, v

    with self.test_session(use_gpu=True):
      with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                   "must be a scalar"):
        functional_ops.For([0], 10, 1, [0.0], Foo)[0].eval()
      with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                   "Invalid start/limit/delta"):
        functional_ops.For(0, 10, -1, [0.0], Foo)[0].eval()
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          "For loop body returned 2 arguments. Expected: 1"):
        functional_ops.For(0, 10, 1, [0.0], ReturnsTooManyArgs)[0].eval()

  @test_util.run_deprecated_v1
  def testGradient(self):

    @function.Defun(dtypes.float32)
    def Poly(x):
      # y = 2x^3+3x^2+4x+8
      return 2 * x * x * x + 3 * x * x + 4 * x + 8

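    # functional_ops.Gradient([x, dy], f) returns the symbolic gradient of f
    # with respect to its inputs, seeded with dy; with dy = 1.0 this is simply
    # dPoly/dx evaluated at x.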
    @function.Defun(dtypes.float32)
    def Grad(x):
      # dy/dx = dy/dy * dy/dx = 1.0 * (6x^2+6x+4)
      return functional_ops.Gradient([x, 1.0], Poly)[0]

    with self.test_session(use_gpu=False) as sess:
      a = constant_op.constant(0.)
      avals = [Poly(a), Grad(a)]
      b = constant_op.constant(1.)
      bvals = [Poly(b), Grad(b)]
      self.assertAllEqual(self.evaluate(avals), [8., 4.])
      self.assertAllEqual(self.evaluate(bvals), [17., 16.])


# TODO(akshayka): Replace `function.Defun` with `tf.contrib.eager.defun` in the
# test cases below.
class PartitionedCallTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testBasicSingleDevice(self):

    @function.Defun(*[dtypes.float32] * 2)
    def Body(x, y):
      with ops.device("/cpu:0"):
        a = x + x
        b = y + y
        return a + b

    output, = self.evaluate(
        functional_ops.partitioned_call(
            args=[constant_op.constant(1.),
                  constant_op.constant(2.)], f=Body))
    self.assertEqual(output, 6.)

  @test_util.run_deprecated_v1
  def testBasicMultiDevice(self):
    config = config_pb2.ConfigProto(device_count={"CPU": 3})

    @function.Defun(*[dtypes.float32] * 2)
    def Body(x, y):
      # if x = 1, y = 2, ...
      with ops.device("/cpu:0"):
        # a:= 1 + 1 = 2
        a = x + x
      with ops.device("/cpu:1"):
        # b:= 2 + 2 = 4
        b = a + y
      with ops.device("/cpu:2"):
        # c:= 2 + 4 = 6
        c = a + b
      # a + b + c = 2 + 4 + 6 = 12
      return a + b + c

    with self.test_session(config=config):
      output, = functional_ops.partitioned_call(
          args=[constant_op.constant(1.),
                constant_op.constant(2.)], f=Body)
      self.assertEqual(output.eval(), 12.)

  @test_util.run_deprecated_v1
  def testBasicMultiDeviceGPU(self):
    if not test_util.is_gpu_available():
      return

    @function.Defun(*[dtypes.float32] * 2)
    def Body(x, y):
      with ops.device("/gpu:0"):
        a = x + x
        b = y + y
      with ops.device("/cpu:0"):
        c = a + b
        return c

    output, = self.evaluate(
        functional_ops.partitioned_call(
            args=[constant_op.constant(1.),
                  constant_op.constant(2.)], f=Body))
    self.assertEqual(output, 6.)

  @test_util.run_deprecated_v1
  def testBasicNoDeviceAnnotations(self):

    @function.Defun(*[dtypes.float32] * 2)
    def Body(x, y):
      a = x + x
      b = y + y
      return a + b

    output, = self.evaluate(
        functional_ops.partitioned_call(
            args=[constant_op.constant(1.),
                  constant_op.constant(2.)], f=Body))
    self.assertEqual(output, 6.)

  @test_util.run_deprecated_v1
  def testShardsRunOnRequestedDevices(self):
    config = config_pb2.ConfigProto(device_count={"CPU": 4})

    @function.Defun()
    def Body():
      # Serialize DT_RESOURCE handles as DT_STRINGs, which encode the device on
      # which the resource was created, so that we can verify that ops were
      # actually run on the requested devices.
      #
      # TODO(akshayka): Provide a cleaner, more idiomatic API for obtaining the
      # name of the device on which a resource lives / for determining the
      # device on which an op ran.
      with ops.device("/cpu:0"):
        s1 = iterator_ops.Iterator.from_structure(
            (dtypes.float32,)).string_handle()
      with ops.device("/cpu:1"):
        s2 = iterator_ops.Iterator.from_structure(
            (dtypes.float32,)).string_handle()
      with ops.device("/cpu:2"):
        s3 = iterator_ops.Iterator.from_structure(
            (dtypes.float32,)).string_handle()
      return s1, s2, s3

    with self.test_session(config=config, use_gpu=True) as sess:
      outputs = sess.run(functional_ops.partitioned_call(args=[], f=Body))
    self.assertIn(compat.as_bytes("CPU:0"), outputs[0])
    self.assertIn(compat.as_bytes("CPU:1"), outputs[1])
    self.assertIn(compat.as_bytes("CPU:2"), outputs[2])

  @test_util.run_deprecated_v1
  def testAssignAddResourceVariable(self):

    v = resource_variable_ops.ResourceVariable(1.0)

    @function.Defun()
    def AssignAdd():
      v.assign_add(1.0)

    op = functional_ops.partitioned_call(
        args=AssignAdd.captured_inputs, f=AssignAdd)
    _ = self.evaluate(variables.global_variables_initializer())
    _ = self.evaluate(op)
    value = self.evaluate(v.read_value())
    self.assertEqual(value, 2.0)

  @test_util.run_deprecated_v1
  def testFunctionWithResourcesOnDifferentDevices(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPUs available.")

    with ops.device("/cpu:0"):
      v_cpu_zero = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name="v_cpu_zero")

    with ops.device("/cpu:1"):
      v_cpu_one = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name="v_cpu_one")

    with ops.device("/gpu:0"):
      v_gpu = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name="v_gpu")

    def sum_gather():
      cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_zero, [1, 2]))
      also_cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_one, [1, 2]))
      gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
      return cpu_result, also_cpu_result, gpu_result

    defined = function.Defun()(sum_gather)
    with self.test_session(
        config=config_pb2.ConfigProto(
            allow_soft_placement=False,
            log_device_placement=True,
            device_count={"CPU": 2})) as sess:
      self.evaluate(variables.global_variables_initializer())
      expected = self.evaluate(sum_gather())
      result = sess.run(
          functional_ops.partitioned_call(
              args=defined.captured_inputs, f=defined))
      self.assertAllEqual(expected, result)

  # Use an invalid executor name to test the plumbing of the executor_type attr.
  @test_util.run_v1_only("b/120545219")
  def testExecutorTypeAttrExecutorNotFound(self):
    @function.Defun(dtypes.int32)
    def AddFive(x):
      return x + 5

    op = functional_ops.partitioned_call(
        args=[constant_op.constant([1, 2, 3], dtype=dtypes.int32)],
        f=AddFive,
        executor_type="NON_EXISTENT_EXECUTOR")
    with self.assertRaisesRegexp(errors.NotFoundError,
                                 "NON_EXISTENT_EXECUTOR"):
      self.evaluate(op)


@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class FunctionalOpsCaseTest(test.TestCase):

  def testCase(self):
    @eager_function.defun
    def two(x):
      return x * 2

    @eager_function.defun
    def three(x):
      return x * 3

    @eager_function.defun
    def four(x):
      return x * 4

    def f(branch, x):
      tmpl = array_ops.zeros_like(x)
      return array_ops.identity(gen_functional_ops.case(
          branch, input=[x], Tout=[dtypes.float32],
          branches=[f.get_concrete_function(tmpl)
                    for f in (two, three, four)])[0])
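    # Out-of-range branch indices fall through to the last branch (four), as
    # the -1 and 6 cases below check.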
    one = array_ops.ones([])
    self.assertAllEqual(np.float32(2), self.evaluate(f(0, one)))
    self.assertAllEqual(np.float32(3), self.evaluate(f(1, one)))
    self.assertAllEqual(np.float32(4), self.evaluate(f(2, one)))
    self.assertAllEqual(np.float32(4), self.evaluate(f(-1, one)))  # <0 default
    self.assertAllEqual(np.float32(4), self.evaluate(f(6, one)))  # >=N default


if __name__ == "__main__":
  test.main()

# pylint: enable=invalid-name