# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.kernels.functional_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.util import compat


# pylint: disable=invalid-name
def simple_scoped_fn(a, x):
  """Simple function: (a, x) -> 2(x+a), but with "2" as a variable in scope."""
  with variable_scope.variable_scope("body"):
    # Dummy variable, just to check that scoping works as intended.
    two = variable_scope.get_variable(
        "two", [],
        dtype=dtypes.int32,
        initializer=init_ops.constant_initializer(2))
    return math_ops.multiply(math_ops.add(a, x), two)


@test_util.with_control_flow_v2
class FunctionalOpsTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testFoldl_Simple(self):
    elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

    r = functional_ops.foldl(
        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
        elems)
    self.assertAllEqual(208, self.evaluate(r))

    r = functional_ops.foldl(
        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
        elems,
        initializer=10)
    self.assertAllEqual(880, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testFoldl_SingleInputMultiOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array([1, -1.0])
    r = functional_ops.foldl(lambda a, x: a + x, elems, initializer)
    r_value = self.evaluate(r)

    self.assertAllEqual(22, r_value[0])
    self.assertAllEqual(20, r_value[1])

  @test_util.run_in_graph_and_eager_modes
  def testFoldl_MultiInputSingleOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array(1.0)
    r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems),
                             initializer)
    self.assertAllEqual(1, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testFoldl_MultiInputDifferentDimsSingleOutput(self):
    elems = np.array([[1.0, 1.0, 1.0], [2.0, 3.0, 4.0]])
    other_elems = np.array([-1.0, 1.0])
    initializer = np.array([0.0, 0.0, 0.0])
    r = functional_ops.foldl(lambda a, x: a + x[0] * x[1],
                             (elems, other_elems), initializer)
    self.assertAllEqual([1.0, 2.0, 3.0], self.evaluate(r))

  @test_util.run_deprecated_v1
  def testFoldl_Scoped(self):
    with self.cached_session() as sess:
      with variable_scope.variable_scope("root") as varscope:
        elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

        r = functional_ops.foldl(simple_scoped_fn, elems)
        # Check that we have the one variable we asked for here.
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertEqual(variables.trainable_variables()[0].name,
                         "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
        self.assertAllEqual(208, self.evaluate(r))

        # Now let's reuse our single variable.
        varscope.reuse_variables()
        r = functional_ops.foldl(simple_scoped_fn, elems, initializer=10)
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertAllEqual(880, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testFoldr_Simple(self):
    elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

    r = functional_ops.foldr(
        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
        elems)
    self.assertAllEqual(450, self.evaluate(r))

    r = functional_ops.foldr(
        lambda a, x: math_ops.multiply(math_ops.add(a, x), 2),
        elems,
        initializer=10)
    self.assertAllEqual(1282, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testFoldr_SingleInputMultiOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array([1, -1.0])
    r = functional_ops.foldr(lambda a, x: a + x, elems, initializer)
    r_value = self.evaluate(r)

    self.assertAllEqual(22, r_value[0])
    self.assertAllEqual(20, r_value[1])

  @test_util.run_in_graph_and_eager_modes
  def testFoldr_MultiInputSingleOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array(1.0)
    r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems),
                             initializer)
    self.assertAllEqual(1, self.evaluate(r))

  @test_util.run_deprecated_v1
  def testFoldr_Scoped(self):
    with self.cached_session() as sess:
      with variable_scope.variable_scope("root") as varscope:
        elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

        r = functional_ops.foldr(simple_scoped_fn, elems)
        # Check that we have the one variable we asked for here.
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertEqual(variables.trainable_variables()[0].name,
                         "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
        self.assertAllEqual(450, self.evaluate(r))

        # Now let's reuse our single variable.
        varscope.reuse_variables()
        r = functional_ops.foldr(simple_scoped_fn, elems, initializer=10)
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertAllEqual(1282, self.evaluate(r))

  # pylint: disable=unnecessary-lambda
  @test_util.run_deprecated_v1
  def testFold_Grad(self):
    with self.cached_session():
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
      v = constant_op.constant(2.0, name="v")
      r = functional_ops.foldl(
          lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
      r = gradients_impl.gradients(r, v)[0]
      self.assertAllEqual(720.0, self.evaluate(r))

      r = functional_ops.foldr(
          lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
      r = gradients_impl.gradients(r, v)[0]
      self.assertAllEqual(720.0, self.evaluate(r))
  # pylint: enable=unnecessary-lambda

  @test_util.run_in_graph_and_eager_modes
  def testScan_Simple(self):
    elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
    v = constant_op.constant(2.0, name="v")

    # pylint: disable=unnecessary-lambda
    r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems)
    self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r))

    r = functional_ops.scan(
        lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
    self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))
    # pylint: enable=unnecessary-lambda

  @test_util.run_in_graph_and_eager_modes
  def testScan_Reverse(self):
    elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
    v = constant_op.constant(2.0, name="v")

    # pylint: disable=unnecessary-lambda
    r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems,
                            reverse=True)
    self.assertAllEqual([720., 720., 360., 120., 30., 6.], self.evaluate(r))
    r = functional_ops.scan(
        lambda a, x: math_ops.multiply(a, x), elems, initializer=v,
        reverse=True)
    self.assertAllEqual([1440., 1440., 720., 240., 60., 12.],
                        self.evaluate(r))
    # pylint: enable=unnecessary-lambda

  @test_util.run_in_graph_and_eager_modes
  def testScan_SingleInputMultiOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = (np.array(1.0), np.array(-1.0))
    r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems,
                            initializer)
    r_value = self.evaluate(r)

    self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
    self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])

  @test_util.run_in_graph_and_eager_modes
  def testScan_MultiInputSingleOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array(1.0)
    # Multiply a * 1 each time
    r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]),
                            (elems + 1, -elems), initializer)
    self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testScan_MultiInputSameTypeOutput(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]),
                            (elems, -elems))
    r_value = self.evaluate(r)
    self.assertAllEqual(np.cumsum(elems), r_value[0])
    self.assertAllEqual(np.cumsum(-elems), r_value[1])

  @test_util.run_in_graph_and_eager_modes
  def testScan_MultiOutputMismatchedInitializer(self):
    elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    initializer = np.array(1.0)
    # The initializer is a scalar, but the fn returns a two-element tuple, so
    # the nested structures don't match.
    with self.assertRaisesRegexp(
        ValueError, "two structures don't have the same nested structure"):
      functional_ops.scan(lambda a, x: (a, -a), elems, initializer)

  @test_util.run_deprecated_v1
  def testScan_Scoped(self):
    with self.cached_session() as sess:
      with variable_scope.variable_scope("root") as varscope:
        elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")

        r = functional_ops.scan(simple_scoped_fn, elems)
        # Check that we have the one variable we asked for here.
        self.assertEqual(len(variables.trainable_variables()), 1)
        self.assertEqual(variables.trainable_variables()[0].name,
                         "root/body/two:0")
        sess.run([variables.global_variables_initializer()])
        results = np.array([1, 6, 18, 44, 98, 208])
        self.assertAllEqual(results, self.evaluate(r))

        # Now let's reuse our single variable.
        varscope.reuse_variables()
        r = functional_ops.scan(simple_scoped_fn, elems, initializer=2)
        self.assertEqual(len(variables.trainable_variables()), 1)
        results = np.array([6, 16, 38, 84, 178, 368])
        self.assertAllEqual(results, self.evaluate(r))

  @test_util.run_in_graph_and_eager_modes
  def testScanFoldl_Nested(self):
    elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data")
    inner_elems = constant_op.constant([0.5, 0.5], name="data")

    def r_inner(a, x):
      return functional_ops.foldl(
          lambda b, y: b * y * x, inner_elems, initializer=a)

    r = functional_ops.scan(r_inner, elems)

    # t == 0 (returns 1)
    # t == 1, a == 1, x == 2 (returns 1)
    #   t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1
    #   t_1 == 1, b == 1, y == 0.5, returns b * y * x = 1
    # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25)
    #   t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5
    #   t_1 == 1, b == 1.5, y == 0.5, returns b * y * x = 1.5*1.5
    # t == 3, a == 2.25, x == 4 (returns 9)
    #   t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5
    #   t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9
    self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r))

  @test_util.run_deprecated_v1
  def testScan_Control(self):
    with self.cached_session() as sess:
      s = array_ops.placeholder(dtypes.float32, shape=[None])
      b = array_ops.placeholder(dtypes.bool)

      with ops.control_dependencies([b]):
        c = functional_ops.scan(lambda a, x: x * a, s)
      self.assertAllClose(
          np.array([1.0, 3.0, 9.0]), sess.run(c, {s: [1, 3, 3],
                                                  b: True}))

  @test_util.run_deprecated_v1
  def testScan_Grad(self):
    with self.cached_session():
      elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
      v = constant_op.constant(2.0, name="v")

      # pylint: disable=unnecessary-lambda
      r = functional_ops.scan(
          lambda a, x: math_ops.multiply(a, x), elems, initializer=v)
      # pylint: enable=unnecessary-lambda
      r = gradients_impl.gradients(r, v)[0]
      self.assertAllEqual(873.0, self.evaluate(r))

  @test_util.run_deprecated_v1
  def testScanGradientWithPartStopGradient(self):
    a = variables.Variable(0.0, name="a")
    b = variables.Variable(0.0, name="b")
    elems = array_ops.zeros(5)
    l0, l1 = functional_ops.scan(
        lambda elem_, input_: (a, b), elems, initializer=(0., 0.))
    loss = l0 + array_ops.stop_gradient(l1)
    grad = gradients_impl.gradients(ys=[loss], xs=[a, b])
    with self.test_session(use_gpu=True) as sess:
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(grad)

  @test_util.run_in_graph_and_eager_modes
  def testFoldShape(self):
    x = constant_op.constant([[1, 2, 3], [4, 5, 6]])

    def fn(_, current_input):
      return current_input

    initializer = constant_op.constant([0, 0, 0])
    y = functional_ops.foldl(fn, x, initializer=initializer)
    self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)

  @test_util.run_in_graph_and_eager_modes
  def testScanShape(self):
    x = constant_op.constant([[1, 2, 3], [4, 5, 6]])

    def fn(_, current_input):
      return current_input

    initializer = constant_op.constant([0, 0, 0])
    y = functional_ops.scan(fn, x, initializer=initializer)
    self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)

  # TODO(akshayka): this test fails in eager: the iterable is of length 0 so
  # the body of the while loop never executes
  @test_util.run_deprecated_v1
  def testScanEmptyTensor(self):
    with self.cached_session():
      x = functional_ops.scan(
          lambda x, _: x, math_ops.range(0), initializer=array_ops.ones([2, 4]))
      self.assertAllEqual([0, 2, 4], x.get_shape())
      self.assertAllEqual(x.get_shape(), self.evaluate(x).shape)

  @test_util.run_deprecated_v1
  def testScanUnknownShape(self):
    x = array_ops.placeholder(dtypes.float32)
    initializer = array_ops.placeholder(dtypes.float32)

    def fn(_, current_input):
      return current_input

    y = functional_ops.scan(fn, x, initializer=initializer)
    self.assertIs(None, y.get_shape().dims)

  @test_util.run_deprecated_v1
  def testScanVaryingShape(self):
    with self.cached_session() as sess:
      x = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 2])
      x_t = array_ops.transpose(x)
      # scan over dimension 0 (with shape None)
      result = functional_ops.scan(lambda a, x: a + x, x)
      # scanned over transposed dimension 0 (with shape 2)
      result_t = functional_ops.scan(lambda a, x: a + x, x_t, infer_shape=False)
      # ensure gradients can be calculated
      result_grad = gradients_impl.gradients(result, [x])[0]
      result_t_grad = gradients_impl.gradients(result_t, [x_t])[0]

      # smoke test to ensure they all evaluate
      sess.run([result, result_t, result_grad, result_t_grad],
               feed_dict={x: [[1.0, 2.0]]})

  @test_util.run_deprecated_v1
  def testRemoteFunction(self):
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 2
    worker, _ = test_util.create_local_cluster(
        1, 1, worker_config=worker_config)

    @function.Defun(dtypes.int32, dtypes.int32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:ps/task:0"):
      a = variables.Variable(2, dtype=dtypes.int32)
      b = variables.Variable(3, dtype=dtypes.int32)

    with ops.device("/job:worker/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target="/job:worker/replica:0/task:0/cpu:1")

    with session.Session(worker[0].target) as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, [6])

  @test_util.run_deprecated_v1
  def testRemoteFunctionDirectSession(self):
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 2

    @function.Defun(dtypes.int32, dtypes.int32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      a = variables.Variable(2, dtype=dtypes.int32)
      b = variables.Variable(3, dtype=dtypes.int32)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target="/job:localhost/replica:0/task:0/cpu:1")

    with self.test_session(config=worker_config) as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, [6])

  @test_util.run_deprecated_v1
  def testRemoteFunctionSameDeviceDirectSession(self):

    @function.Defun(dtypes.int32, dtypes.int32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/cpu:0"):
      a = variables.Variable(2, dtype=dtypes.int32)
      b = variables.Variable(3, dtype=dtypes.int32)

    with ops.device("/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0")

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, [6])

  @test_util.run_deprecated_v1
  def testRemoteFunctionCPUGPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    @function.Defun(dtypes.float32, dtypes.float32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      a = variables.Variable(2, dtype=dtypes.float32)
      b = variables.Variable(3, dtype=dtypes.float32)

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.float32],
          f=_remote_fn,
          target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, 9.0)

  @test_util.run_deprecated_v1
  def testRemoteFunctionGPUCPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    @function.Defun(dtypes.float32, dtypes.float32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
      a = variables.Variable(2, dtype=dtypes.float32)
      b = variables.Variable(3, dtype=dtypes.float32)

    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.float32],
          f=_remote_fn,
          target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, 9.0)

  @test_util.run_deprecated_v1
  def testRemoteFunctionGPUCPUStrings(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    @function.Defun(dtypes.string)
    def _remote_fn(inp):
      return array_ops.identity(inp)

    a = array_ops.constant("a")

    with ops.device("/gpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a], Tout=[dtypes.string], f=_remote_fn, target="/cpu:0")

    with self.cached_session() as sess:
      ret = self.evaluate(remote_op)
      self.assertAllEqual(ret, [b"a"])

  @test_util.run_deprecated_v1
  def testRemoteFunctionCrossProcess(self):
    workers, _ = test_util.create_local_cluster(2, 1)

    @function.Defun(dtypes.float32, dtypes.float32)
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    with ops.device("/job:ps/task:0"):
      a = variables.Variable(2, dtype=dtypes.float32)
      b = variables.Variable(3, dtype=dtypes.float32)

    with ops.device("/job:worker/replica:0/task:0/cpu:0"):
      remote_op = functional_ops.remote_call(
          args=[a, b],
          Tout=[dtypes.float32],
          f=_remote_fn,
          target="/job:worker/replica:0/task:1/cpu:0")[0] + 3.0

    with session.Session(workers[0].target) as sess:
      self.evaluate(variables.global_variables_initializer())
      mul = self.evaluate(remote_op)
      self.assertEqual(mul, 9)

  @test_util.run_deprecated_v1
  def testIf(self):

    @function.Defun(dtypes.float32)
    def Twice(x):
      return x * 2

    @function.Defun(dtypes.float32)
    def Thrice(x):
      return x * 3 + 1

    with self.test_session(use_gpu=False) as sess:

      x = array_ops.placeholder(dtypes.float32)
      ret = functional_ops.If(math_ops.greater(x, 0), [x], Twice, Thrice)[0]

      self.assertAllEqual(sess.run(ret, feed_dict={x: 9.}), 18.)
      self.assertAllEqual(sess.run(ret, feed_dict={x: -8.}), -23.)
      self.assertAllEqual(sess.run(ret, feed_dict={x: 0.}), 1.)

  def testWhile(self):

    for use_gpu in (True, False):
      with ops.Graph().as_default() as g:

        @function.Defun(*[dtypes.float32] * 2)
        def Cond(n, unused_x):
          return n > 0

        @function.Defun(*[dtypes.float32] * 2)
        def Body(n, x):
          return n - 1, x + n

        def Run(sess, n):
          return sess.run(functional_ops.While([n, 0.], Cond, Body))[1]

        with self.session(graph=g, use_gpu=use_gpu) as sess:
          self.assertAllEqual(Run(sess, 20.), 210.)
          self.assertAllEqual(Run(sess, 100.), 5050.)

  # Like above, but using int32 in order to ensure that int32 tensors don't get
  # copied to the GPU during the application of the while.
  def testWhileInt32(self):
    with ops.Graph().as_default() as g:

      @function.Defun(*[dtypes.int32] * 2)
      def Cond(n, unused_x):
        return n > 0

      @function.Defun(*[dtypes.int32] * 2)
      def Body(n, x):
        return n - 1, x + n

      def Run(sess, n):
        return sess.run(functional_ops.While([n, 0], Cond, Body))[1]

      with self.session(graph=g, use_gpu=True) as sess:
        self.assertAllEqual(Run(sess, 20), 210)
        self.assertAllEqual(Run(sess, 100), 5050)

  @test_util.run_deprecated_v1
  def testWhileLowering(self):

    def Run(n, fetch_by_name):
      for use_gpu in (True, False):
        with ops.Graph().as_default() as g:

          @function.Defun(*[dtypes.float32] * 2)
          def Cond(n, unused_x):
            return n > 0

          @function.Defun(*[dtypes.float32] * 2)
          def Body(n, x):
            return n - 1, x + n

          # outputs: [0, n*(n+1)/2]
          outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while")

          # `outputs` is the list of output tensors of the While op. We
          # arbitrarily choose the 0th tensor to get the While op and set the
          # lowering attribute on it.
          outputs[0].op._set_attr("_lower_using_switch_merge",
                                  attr_value_pb2.AttrValue(b=True))
          if not fetch_by_name:
            fetch = outputs[1]
          else:
            fetch = "my_while:1"
          with self.session(graph=g, use_gpu=use_gpu) as sess:
            return self.evaluate(fetch)

    self.assertAllEqual(Run(20., False), 210.)
    self.assertAllEqual(Run(20., True), 210.)
    self.assertAllEqual(Run(100., False), 5050.)
    self.assertAllEqual(Run(100., True), 5050.)

  @test_util.run_v1_only("b/120545219")
  @test_util.disable_xla("b/123337890")  # Different error message
  def testWhileError(self):
    for use_gpu in (True, False):
      with ops.Graph().as_default() as g:

        @function.Defun(*[dtypes.float32] * 2)
        def Cond(n, unused_x):
          return n > 0

        @function.Defun(*[dtypes.float32] * 2)
        def CondReturnsTooManyArgs(n, x):
          return n > 0, x

        @function.Defun(*[dtypes.float32] * 2)
        def Body(n, x):
          return n - 1, x + n

        @function.Defun(*[dtypes.float32] * 2)
        def BodyReturnsTooManyArgs(n, x):
          return n - 1, x + n, x

        with self.session(graph=g, use_gpu=use_gpu):
          with self.assertRaisesRegexp(
              errors.InvalidArgumentError,
              "Expected a single scalar.*got 2 tensors."):
            functional_ops.While([5., 0.], CondReturnsTooManyArgs,
                                 Body)[0].eval()
          with self.assertRaisesRegexp(
              errors.InvalidArgumentError,
              "While loop body returned 3 arguments. Expected: 2"):
            functional_ops.While([5., 0.], Cond,
                                 BodyReturnsTooManyArgs)[0].eval()

  def testWhileInMultipleSubgraphs(self):

    for use_gpu in (True, False):
      with ops.Graph().as_default() as g:

        @function.Defun(*[dtypes.float32] * 2)
        def Cond(n, x):  # pylint: disable=unused-argument
          return n > 0

        @function.Defun(*[dtypes.float32] * 2)
        def Body(n, x):
          return n - 1, x + n

        with self.session(graph=g, use_gpu=use_gpu) as sess:
          n = array_ops.placeholder(dtypes.float32)
          _, result = functional_ops.While([n, 0.], Cond, Body)
          c = constant_op.constant(37.)

          self.assertAllEqual(210., sess.run(result, feed_dict={n: 20.}))
          self.assertAllEqual(5050., sess.run(result, feed_dict={n: 100.}))
          # Test that the result is the same when we run a different subgraph.
          self.assertAllEqual(5050.,
                              sess.run([result, c], feed_dict={n: 100.})[0])

  # pylint: disable=cell-var-from-loop
  def testWhileCapturedInputs(self):
    for use_gpu in (True, False):
      with ops.Graph().as_default() as g:
        v = variables.Variable(1.0)

        def TestCond(n, *args):
          del args
          return n < 10

        @function.Defun(*[dtypes.float32] * 2)
        def TestUnary(n, x):
          return math_ops.add(n, 1), x + n + v

        @function.Defun(*[dtypes.float32] * 3)
        def TestBinary(n, x, x2):
          return math_ops.add(n, 1), x + n + v, x2 + v

        with self.session(graph=g, use_gpu=use_gpu) as sess:
          result_unary = functional_ops.While(
              [1.0, 0.],
              function.Defun(*[dtypes.float32] * 2)(TestCond), TestUnary)
          result_binary = functional_ops.While(
              [1.0, 0., 0.],
              function.Defun(*[dtypes.float32] * 3)(TestCond), TestBinary)
          self.evaluate(variables.global_variables_initializer())
          assert len(result_unary) == 2
          self.assertEqual([10.0, 54.0], self.evaluate(result_unary))
          assert len(result_binary) == 3
          self.assertEqual([10.0, 54.0, 9.0], self.evaluate(result_binary))

          def TestCondCapture(n, *args):
            del args
            return math_ops.cast(n, dtypes.float32) + v < 10

          with self.assertRaises(ValueError):
            _ = functional_ops.While(
                [1],
                function.Defun(dtypes.int32)(TestCondCapture),
                function.Defun(dtypes.int32, dtypes.float32)(TestUnary))

  # pylint: enable=cell-var-from-loop

  def _tfSum(self, use_gpu, rewrite_with_while):
    with ops.Graph().as_default() as g:
      with self.session(graph=g, use_gpu=use_gpu) as sess:

        @function.Defun(dtypes.int32, dtypes.float32)
        def Body(n, x):
          return x + math_ops.cast(n, dtypes.float32)

        xs = [
            # 1 + 2 + ... + 20
            functional_ops.For(
                1, 21, 1, [0.], Body, rewrite_with_while=rewrite_with_while)[0],
            # 100 + 99 + ... + 1
            functional_ops.For(
                100, 0, -1, [0.], Body, rewrite_with_while=rewrite_with_while)
            [0],
        ]
        xvals = self.evaluate(xs)
        self.assertAllEqual(210, xvals[0])
        self.assertAllEqual(5050, xvals[1])

  def testFor(self):
    for use_gpu in (True, False):
      self._tfSum(use_gpu, False)

  def testForWithWhile(self):
    for use_gpu in (True, False):
      self._tfSum(use_gpu, True)

  def testForWithWhileNaming(self):
    g = ops.Graph()
    with g.as_default():

      @function.Defun(dtypes.int32, dtypes.float32, func_name="TestBody")
      def TestBody(n, x):
        return x + math_ops.cast(n, dtypes.float32)

      _ = functional_ops.For(
          1, 21, 1, [0.], TestBody, rewrite_with_while=True)[0]

    names = []
    for func in g.as_graph_def().library.function:
      names.append(func.signature.name)
    self.assertTrue("TestBody" in names)
    self.assertTrue("TestBody_Cond" in names)
    self.assertTrue("TestBody_Body" in names)

  @test_util.run_deprecated_v1
  def testForCapturedInputs(self):
    v = variables.Variable(1.0)

    @function.Defun(dtypes.int32)
    def TestNullary(n):
      v + math_ops.cast(n, dtypes.float32)  # pylint: disable=expression-not-assigned

    @function.Defun(dtypes.int32, dtypes.float32)
    def TestUnary(n, x):
      return x + math_ops.cast(n, dtypes.float32) + v

    @function.Defun(dtypes.int32, dtypes.float32, dtypes.float32)
    def TestBinary(n, x, x2):
      return x + math_ops.cast(n, dtypes.float32) + v, x2 + v

    for rewrite_with_while in (True, False):
      use_gpu = not rewrite_with_while
      with self.test_session(use_gpu=use_gpu) as sess:
        result_nullary = functional_ops.For(
            1, 10, 1, [], TestNullary,
            rewrite_with_while=rewrite_with_while)
        result_unary = functional_ops.For(
            1, 10, 1, [0.], TestUnary,
            rewrite_with_while=rewrite_with_while)
        result_binary = functional_ops.For(
            1, 10, 1, [0., 0.], TestBinary,
            rewrite_with_while=rewrite_with_while)
        self.evaluate(variables.global_variables_initializer())
        assert not result_nullary
        # The nullary variant doesn't return anything, so we can't easily run
        # it. As a total hack, fetch the operation by name and run it.
        sess.run(ops.get_default_graph().get_operation_by_name(
            "While" if rewrite_with_while else "For"))
        assert len(result_unary) == 1
        self.assertEqual([54.0], self.evaluate(result_unary))
        assert len(result_binary) == 2
        self.assertEqual([54.0, 9.0], self.evaluate(result_binary))

  def _tfMLP(self, xval, wsval, bsval, rewrite_with_while):
    # On GPU, don't rewrite using a while loop.
    use_gpu = not rewrite_with_while
    with self.test_session(use_gpu=use_gpu):

      @function.Defun(dtypes.int32, *[dtypes.float64] * 3)
      def MLP(i, a, ws, bs):
        a = math_ops.tanh(math_ops.matmul(a, ws[i, :]) + bs[i, :])
        return a, ws, bs

      ret = functional_ops.For(
          0,
          wsval.shape[0],
          1, [xval, wsval, bsval],
          MLP,
          rewrite_with_while=rewrite_with_while)[0]

      return self.evaluate(ret)

  def _npMLP(self, xval, wsval, bsval):
    for i in range(wsval.shape[0]):
      xval = np.tanh(np.dot(xval, wsval[i, :]) + bsval[i, :])
    return xval

  def _testForMLP(self, rewrite_with_while):
    # We construct a 5-layer multi-layer perceptron here. Each layer has the
    # same number of hidden units (3), and the activation function is tanh().
    # We feed the input (xval) with batch size 2.
    xval = np.random.normal(size=(2, 3))
    wsval = np.random.normal(size=(5, 3, 3))
    bsval = np.random.normal(size=(5, 3))
    np_ans = self._npMLP(xval, wsval, bsval)
    tf_for_ans = self._tfMLP(xval, wsval, bsval, rewrite_with_while)
    self.assertAllClose(np_ans, tf_for_ans)

  @test_util.run_deprecated_v1
  def testForMLP(self):
    self._testForMLP(False)

  @test_util.run_deprecated_v1
  def testForMLPWhile(self):
    self._testForMLP(True)

  @test_util.run_v1_only("b/120545219")
  def testForError(self):

    @function.Defun(dtypes.int32, dtypes.float32)
    def Foo(i, v):
      return math_ops.cast(i, dtypes.float32) + v

    @function.Defun(dtypes.int32, dtypes.float32)
    def ReturnsTooManyArgs(unused_i, v):
      return v, v

    with self.test_session(use_gpu=True):
      with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                   "must be a scalar"):
        functional_ops.For([0], 10, 1, [0.0], Foo)[0].eval()
      with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                   "Invalid start/limit/delta"):
        functional_ops.For(0, 10, -1, [0.0], Foo)[0].eval()
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          "For loop body returned 2 arguments. Expected: 1"):
        functional_ops.For(0, 10, 1, [0.0], ReturnsTooManyArgs)[0].eval()

  @test_util.run_deprecated_v1
  def testGradient(self):

    @function.Defun(dtypes.float32)
    def Poly(x):
      # y = 2x^3+3x^2+4x+8
      return 2 * x * x * x + 3 * x * x + 4 * x + 8

    @function.Defun(dtypes.float32)
    def Grad(x):
      # dy/dx = dy/dy * dy/dx = 1.0 * (6x^2+6x+4)
      return functional_ops.Gradient([x, 1.0], Poly)[0]

    with self.test_session(use_gpu=False) as sess:
      a = constant_op.constant(0.)
      avals = [Poly(a), Grad(a)]
      b = constant_op.constant(1.)
      bvals = [Poly(b), Grad(b)]
      self.assertAllEqual(self.evaluate(avals), [8., 4.])
      self.assertAllEqual(self.evaluate(bvals), [17., 16.])


# TODO(akshayka): Replace `function.Defun` with `tf.contrib.eager.defun` in
# the test cases below.
class PartitionedCallTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testBasicSingleDevice(self):

    @function.Defun(*[dtypes.float32] * 2)
    def Body(x, y):
      with ops.device("/cpu:0"):
        a = x + x
        b = y + y
        return a + b

    output, = self.evaluate(
        functional_ops.partitioned_call(
            args=[constant_op.constant(1.),
                  constant_op.constant(2.)], f=Body))
    self.assertEqual(output, 6.)

  @test_util.run_deprecated_v1
  def testBasicMultiDevice(self):
    config = config_pb2.ConfigProto(device_count={"CPU": 3})

    @function.Defun(*[dtypes.float32] * 2)
    def Body(x, y):
      # if x = 1, y = 2, ...
      with ops.device("/cpu:0"):
        # a := 1 + 1 = 2
        a = x + x
      with ops.device("/cpu:1"):
        # b := 2 + 2 = 4
        b = a + y
      with ops.device("/cpu:2"):
        # c := 2 + 4 = 6
        c = a + b
      # a + b + c = 2 + 4 + 6 = 12
      return a + b + c

    with self.test_session(config=config):
      output, = functional_ops.partitioned_call(
          args=[constant_op.constant(1.),
                constant_op.constant(2.)], f=Body)
      self.assertEqual(output.eval(), 12.)

  @test_util.run_deprecated_v1
  def testBasicMultiDeviceGPU(self):
    if not test_util.is_gpu_available():
      return

    @function.Defun(*[dtypes.float32] * 2)
    def Body(x, y):
      with ops.device("/gpu:0"):
        a = x + x
        b = y + y
      with ops.device("/cpu:0"):
        c = a + b
      return c

    output, = self.evaluate(
        functional_ops.partitioned_call(
            args=[constant_op.constant(1.),
                  constant_op.constant(2.)], f=Body))
    self.assertEqual(output, 6.)

  @test_util.run_deprecated_v1
  def testBasicNoDeviceAnnotations(self):

    @function.Defun(*[dtypes.float32] * 2)
    def Body(x, y):
      a = x + x
      b = y + y
      return a + b

    output, = self.evaluate(
        functional_ops.partitioned_call(
            args=[constant_op.constant(1.),
                  constant_op.constant(2.)], f=Body))
    self.assertEqual(output, 6.)

  @test_util.run_deprecated_v1
  def testShardsRunOnRequestedDevices(self):
    config = config_pb2.ConfigProto(device_count={"CPU": 4})

    @function.Defun()
    def Body():
      # Serialize DT_RESOURCE handles as DT_STRINGs, which encode the device
      # on which the resource was created, so that we can verify that ops
      # were actually run on the requested devices.
      #
      # TODO(akshayka): Provide a cleaner, more idiomatic API for obtaining the
      # name of the device on which a resource lives / for determining the
      # device on which an op ran.
      with ops.device("/cpu:0"):
        s1 = iterator_ops.Iterator.from_structure(
            (dtypes.float32,)).string_handle()
      with ops.device("/cpu:1"):
        s2 = iterator_ops.Iterator.from_structure(
            (dtypes.float32,)).string_handle()
      with ops.device("/cpu:2"):
        s3 = iterator_ops.Iterator.from_structure(
            (dtypes.float32,)).string_handle()
      return s1, s2, s3

    with self.test_session(config=config, use_gpu=True) as sess:
      outputs = sess.run(functional_ops.partitioned_call(args=[], f=Body))
      self.assertIn(compat.as_bytes("CPU:0"), outputs[0])
      self.assertIn(compat.as_bytes("CPU:1"), outputs[1])
      self.assertIn(compat.as_bytes("CPU:2"), outputs[2])

  @test_util.run_deprecated_v1
  def testAssignAddResourceVariable(self):

    v = resource_variable_ops.ResourceVariable(1.0)

    @function.Defun()
    def AssignAdd():
      v.assign_add(1.0)

    op = functional_ops.partitioned_call(
        args=AssignAdd.captured_inputs, f=AssignAdd)
    _ = self.evaluate(variables.global_variables_initializer())
    _ = self.evaluate(op)
    value = self.evaluate(v.read_value())
    self.assertEqual(value, 2.0)

  @test_util.run_deprecated_v1
  def testFunctionWithResourcesOnDifferentDevices(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPUs available.")

    with ops.device("/cpu:0"):
      v_cpu_zero = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name="v_cpu_zero")

    with ops.device("/cpu:1"):
      v_cpu_one = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name="v_cpu_one")

    with ops.device("/gpu:0"):
      v_gpu = resource_variable_ops.ResourceVariable(
          [0.0, 1.0, 2.0], name="v_gpu")

    def sum_gather():
      cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_zero, [1, 2]))
      also_cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_one, [1, 2]))
      gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
      return cpu_result, also_cpu_result, gpu_result

    defined = function.Defun()(sum_gather)
    with self.test_session(
        config=config_pb2.ConfigProto(
            allow_soft_placement=False,
            log_device_placement=True,
            device_count={"CPU": 2})) as sess:
      self.evaluate(variables.global_variables_initializer())
      expected = self.evaluate(sum_gather())
      result = sess.run(
          functional_ops.partitioned_call(
              args=defined.captured_inputs, f=defined))
      self.assertAllEqual(expected, result)

  # Use an invalid executor name to test the plumbing of the executor_type
  # attr.
  @test_util.run_v1_only("b/120545219")
  def testExecutorTypeAttrExecutorNotFound(self):
    @function.Defun(dtypes.int32)
    def AddFive(x):
      return x + 5

    op = functional_ops.partitioned_call(
        args=[constant_op.constant([1, 2, 3], dtype=dtypes.int32)],
        f=AddFive,
        executor_type="NON_EXISTENT_EXECUTOR")
    with self.assertRaisesRegexp(errors.NotFoundError,
                                 "NON_EXISTENT_EXECUTOR"):
      self.evaluate(op)


@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class FunctionalOpsCaseTest(test.TestCase):

  def testCase(self):
    @eager_function.defun
    def two(x):
      return x * 2

    @eager_function.defun
    def three(x):
      return x * 3

    @eager_function.defun
    def four(x):
      return x * 4

    def f(branch, x):
      tmpl = array_ops.zeros_like(x)
      return array_ops.identity(gen_functional_ops.case(
          branch, input=[x], Tout=[dtypes.float32],
          branches=[f.get_concrete_function(tmpl)
                    for f in (two, three, four)])[0])

    one = array_ops.ones([])
    self.assertAllEqual(np.float32(2), self.evaluate(f(0, one)))
    self.assertAllEqual(np.float32(3), self.evaluate(f(1, one)))
    self.assertAllEqual(np.float32(4), self.evaluate(f(2, one)))
    self.assertAllEqual(np.float32(4), self.evaluate(f(-1, one)))  # <0 default
    self.assertAllEqual(np.float32(4), self.evaluate(f(6, one)))  # >=N default


if __name__ == "__main__":
  test.main()

# pylint: enable=invalid-name