# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for running legacy optimizer code with DistributionStrategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.distribute import optimizer_combinations
from tensorflow.python.keras.distribute.test_example import batchnorm_example
from tensorflow.python.keras.distribute.test_example import minimize_loss_example
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import test


VAR_MAP_V1 = {
    "GradientDescent": ("dense/kernel", "dense/bias"),
    "Adagrad": ("dense/kernel/Adagrad", "dense/kernel", "dense/bias/Adagrad",
                "dense/bias"),
    "Ftrl": ("dense/kernel/Ftrl", "dense/kernel", "dense/bias/Ftrl",
             "dense/bias", "dense/kernel/Ftrl_1", "dense/bias/Ftrl_1"),
    "RMSProp": ("dense/kernel", "dense/bias/RMSProp", "dense/bias/RMSProp_1",
                "dense/bias", "dense/kernel/RMSProp_1", "dense/kernel/RMSProp")
}

VAR_MAP_V2 = {
    "SGD": ("dense/bias", "SGD/learning_rate", "SGD/decay", "SGD/iter",
            "dense/kernel", "SGD/momentum"),
    "Adagrad":
        ("Adagrad/iter", "dense/bias", "dense/kernel", "Adagrad/learning_rate",
         "Adagrad/decay", "Adagrad/dense/kernel/accumulator",
         "Adagrad/dense/bias/accumulator")
}


class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):

  def _get_iterator(self, strategy, input_fn):
    iterator = strategy.make_input_fn_iterator(lambda _: input_fn())
    self.evaluate(iterator.initializer)
    return iterator

  @ds_combinations.generate(
      combinations.times(
          optimizer_combinations.distributions_and_v1_optimizers(),
          combinations.combine(mode=["graph"], use_callable_loss=[True, False])
          + combinations.combine(mode=["eager"], use_callable_loss=[True])) +
      combinations.times(
          optimizer_combinations.distributions_and_v2_optimizers(),
          combinations.combine(
              mode=["graph", "eager"], use_callable_loss=[True])) +
      combinations.combine(
          distribution=[strategy_combinations.tpu_strategy],
          optimizer_fn=optimizer_combinations.optimizers_v2,
          mode=["graph"],
          use_callable_loss=[True]) + combinations.combine(
              distribution=[strategy_combinations.tpu_strategy],
              optimizer_fn=optimizer_combinations.optimizers_v1,
              mode=["graph"],
              use_callable_loss=[True, False]))
  def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss):
    with distribution.scope():
      optimizer = optimizer_fn()
      model_fn, dataset_fn, layer = minimize_loss_example(
          optimizer, use_bias=True, use_callable_loss=use_callable_loss)

      def step_fn(ctx, inputs):
        del ctx  # Unused
        return distribution.group(
            distribution.extended.call_for_each_replica(
                model_fn, args=(inputs,)))

      iterator = self._get_iterator(distribution, dataset_fn)

      def run_step():
        return distribution.extended.experimental_run_steps_on_iterator(
            step_fn, iterator, iterations=2).run_op

      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())

      weights, biases = [], []
      for _ in range(5):
        run_step()
        weights.append(self.evaluate(layer.kernel))
        biases.append(self.evaluate(layer.bias))

      error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
      is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
      self.assertTrue(is_not_increasing)

  @ds_combinations.generate(
      combinations.times(
          optimizer_combinations.distributions_and_v1_optimizers(),
          combinations.combine(mode=["graph"], use_callable_loss=[True, False])
          + combinations.combine(mode=["eager"], use_callable_loss=[True])) +
      combinations.times(
          optimizer_combinations.distributions_and_v2_optimizers(),
          combinations.combine(
              mode=["graph", "eager"], use_callable_loss=[True])))
  def testTrainNetworkByCallForEachReplica(self, distribution, optimizer_fn,
                                           use_callable_loss):
    with distribution.scope():
      optimizer = optimizer_fn()
      model_fn, dataset_fn, layer = minimize_loss_example(
          optimizer, use_bias=True, use_callable_loss=use_callable_loss)

      iterator = self._get_iterator(distribution, dataset_fn)

      def run_step():
        return distribution.group(
            distribution.extended.call_for_each_replica(
                model_fn, args=(iterator.get_next(),)))

      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())

      weights, biases = [], []
      for _ in range(10):
        run_step()

        weights.append(self.evaluate(layer.kernel))
        biases.append(self.evaluate(layer.bias))

      error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
      is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
      self.assertTrue(is_not_increasing)

  @ds_combinations.generate(
      combinations.times(
          optimizer_combinations.distributions_and_v1_and_v2_optimizers(),
          combinations.combine(mode=["graph", "eager"])) + combinations.combine(
              distribution=[strategy_combinations.tpu_strategy],
              optimizer_fn=optimizer_combinations.optimizers_v1_and_v2,
              mode=["graph"]))
  def testOptimizerInsideModelFn(self, distribution, optimizer_fn):
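    """Verifies the variables created when an optimizer is used in model_fn."""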
    if (not context.executing_eagerly() and
        control_flow_v2_toggles.control_flow_v2_enabled()):
      self.skipTest("b/138751864")
    created_variables = []
    trainable_variables = []

    def appending_creator(next_creator, **kwargs):
      v = next_creator(**kwargs)
      created_variables.append(v.name)
      if "trainable" in kwargs and kwargs["trainable"]:
        trainable_variables.append(v.name)
      return v

    # Creator scope needs to be set before it's used inside
    # `distribution.scope`.
    with variable_scope.variable_creator_scope(
        appending_creator), distribution.scope():
      optimizer = optimizer_fn()
      model_fn, dataset_fn, _ = minimize_loss_example(
          optimizer, use_bias=True, use_callable_loss=True)

      def step_fn(ctx, inputs):
        del ctx  # Unused
        return distribution.group(
            distribution.extended.call_for_each_replica(
                model_fn, args=(inputs,)))

      iterator = self._get_iterator(distribution, dataset_fn)

      def run_step():
        return distribution.extended.experimental_run_steps_on_iterator(
            step_fn, iterator, iterations=1).run_op

      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())
      run_step()

      def get_expected_variables(num_parameter_devices):
        name = optimizer._name

        if isinstance(optimizer, optimizer_v2.OptimizerV2):
          variables = VAR_MAP_V2[name]
        else:
          variables = VAR_MAP_V1[name]

        extended_variables = [
            v + "/replica_{}".format(replica)
            for v in variables
            for replica in range(1, num_parameter_devices)
        ]
        variables = list(variables) + extended_variables
        return set(v + ":0" for v in variables)

      self.assertEqual(
          get_expected_variables(len(distribution.extended.parameter_devices)),
          set(created_variables))

  @ds_combinations.generate(
      combinations.times(
          combinations.combine(momentum=[0.8, 0.9, 0.99], renorm=[False, True]),
          combinations.times(
              optimizer_combinations.distributions_and_v1_and_v2_optimizers(),
              combinations.combine(
                  mode=["graph", "eager"],
                  # TODO(isaprykin): Allow False here. Currently subsequent
                  # replicas will re-execute UPDATE_OPS of previous replicas.
                  update_ops_in_cross_replica_mode=[True])) +
          combinations.combine(
              distribution=[strategy_combinations.tpu_strategy],
              optimizer_fn=optimizer_combinations.optimizers_v1_and_v2,
              mode=["graph"],
              update_ops_in_cross_replica_mode=[False])))
  def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum,
                                    renorm, update_ops_in_cross_replica_mode):
    """Verifies that moving mean updates are reduced across replicas."""
    with distribution.scope():
      num_replicas = distribution.num_replicas_in_sync
      model_fn, dataset_fn, batchnorm = batchnorm_example(
          optimizer_fn,
          batch_per_epoch=num_replicas,
          momentum=momentum,
          renorm=renorm,
          update_ops_in_replica_mode=not update_ops_in_cross_replica_mode)

      def step_fn(ctx, inputs):
        del ctx  # Unused
        fetches = distribution.experimental_local_results(
            distribution.extended.call_for_each_replica(
                model_fn, args=(inputs,)))
        if update_ops_in_cross_replica_mode:
          fetches += tuple(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
        return control_flow_ops.group(fetches)

      iterator = self._get_iterator(distribution, dataset_fn)

      def run_step():
        return distribution.extended.experimental_run_steps_on_iterator(
            step_fn, iterator, iterations=1).run_op

      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())

      expected_moving_means = [0.] * 8

      def averaged_batch_mean(i):
        # Each batch has shape [16, 8] where the ith element of the jth row is
        # (8 * j + i + replica_id * 100). So the batch mean in each replica is
        # (60 + i + replica_id * 100), and the batch mean averaged over all
        # replicas is:
        return 60. + i + (num_replicas - 1.) / 2. * 100.

      for _ in range(10):
        run_step()
        moving_means = self.evaluate(batchnorm.moving_mean)

        # We make sure that the moving_mean is updated as if the sample mean is
        # calculated over all replicas.
        for i, expected_moving_mean in enumerate(expected_moving_means):
          expected_moving_means[i] -= ((
              expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum))
          self.assertNear(expected_moving_means[i], moving_means[i], 0.0001)

  @ds_combinations.generate(
      combinations.times(
          combinations.combine(loss_reduction=[
              losses_impl.Reduction.SUM, losses_impl.Reduction.MEAN,
              losses_impl.Reduction.SUM_OVER_BATCH_SIZE,
              losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS
          ]),
          combinations.times(
              combinations.combine(distribution=[
                  strategy_combinations.one_device_strategy,
                  strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
                  strategy_combinations.mirrored_strategy_with_two_gpus
              ]),
              combinations.times(
                  combinations.combine(optimizer_fn=optimizer_combinations
                                       .gradient_descent_optimizer_v1_fn),
                  combinations.combine(
                      mode=["graph"], use_callable_loss=[True, False]) +
                  combinations.combine(
                      mode=["eager"], use_callable_loss=[True])) +
              combinations.times(
                  combinations.combine(optimizer_fn=optimizer_combinations
                                       .gradient_descent_optimizer_keras_v2_fn),
                  combinations.combine(
                      mode=["graph", "eager"], use_callable_loss=[True]))) +
          combinations.combine(
              distribution=[strategy_combinations.tpu_strategy],
              optimizer_fn=optimizer_combinations
              .gradient_descent_optimizer_v1_fn,
              mode=["graph"],
              use_callable_loss=[True, False]) + combinations.combine(
                  distribution=[strategy_combinations.tpu_strategy],
                  optimizer_fn=optimizer_combinations
                  .gradient_descent_optimizer_keras_v2_fn,
                  mode=["graph"],
                  use_callable_loss=[True])))
  def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction,
                    use_callable_loss):
    with distribution.scope():
      all_vars = []

      def model_fn(inputs):
        x, y = inputs
        w = variable_scope.get_variable("w", initializer=[[2.]])
        all_vars.append(w)

        def loss_fn():
          # Use fixed initialization to make the steps deterministic.
          predict = math_ops.matmul(x, w)
          loss = losses_impl.mean_squared_error(
              y, predict, reduction=loss_reduction)
          if loss_reduction == losses_impl.Reduction.SUM:
            return loss
          return loss / distribution.num_replicas_in_sync

        optimizer = optimizer_fn()  # GradientDescent with 0.001 learning rate

        if isinstance(optimizer, optimizer_v2.OptimizerV2):
          return optimizer.minimize(loss_fn, [w])
        else:
          if use_callable_loss:
            return optimizer.minimize(loss_fn)
          else:
            return optimizer.minimize(loss_fn())

      def dataset_fn():
        features = dataset_ops.Dataset.from_tensors([[2.], [7.]])
        labels = dataset_ops.Dataset.from_tensors([[6.], [21.]])
        return dataset_ops.Dataset.zip((features, labels)).repeat()

      def step_fn(ctx, inputs):
        del ctx  # Unused
        return distribution.group(
            distribution.extended.call_for_each_replica(
                model_fn, args=(inputs,)))

      iterator = self._get_iterator(distribution, dataset_fn)

      def run_step():
        return distribution.extended.experimental_run_steps_on_iterator(
            step_fn, iterator, iterations=1).run_op

      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())

      run_step()

      v = all_vars[0]
      self.assertTrue(all(v is vi for vi in all_vars[1:]))
      weight = numpy.squeeze(self.evaluate(v))
      # Our model is:
      #   predict = x * w
      #   loss = (predict - y)^2
      #   dloss/dpredict = 2 * (predict - y)
      #   dloss/dw = 2 * x^T @ (predict - y)
      # For our batch size of 2, assuming sum loss reduction:
      #   x = [2, 7]
      #   y = [6, 21]
      #   w_initial = 2
      #   predict = [4, 14]
      #   predict - y = [-2, -7]
      #   dloss/dw = 2 * <[2, 7], [-2, -7]> = -2 * (4 + 49) = -106
      # So unreplicated the update to w with lr=0.001 is -0.001 * -106 = 0.106
      # with sum loss reduction, or 0.053 with mean.
      if loss_reduction == losses_impl.Reduction.SUM:
        # Note that the "distribution.num_replicas_in_sync" factor will go away
        # once we split the input across replicas, instead of pulling a complete
        # batch of input per replica.
        self.assertNear(weight, 2 + 0.106 * distribution.num_replicas_in_sync,
                        0.0001)
      else:
        # One of the mean loss reductions.
        self.assertNear(weight, 2 + 0.053, 0.0001)

  @ds_combinations.generate(
      combinations.times(
          optimizer_combinations.distributions_and_v1_and_v2_optimizers(),
          combinations.combine(mode=["graph", "eager"]),
          combinations.combine(is_tpu=[False])) + combinations.combine(
              distribution=[strategy_combinations.tpu_strategy],
              optimizer_fn=optimizer_combinations.optimizers_v1_and_v2,
              mode=["graph"],
              is_tpu=[True]))
  def testRunStepsWithOutputContext(self, distribution, optimizer_fn, is_tpu):
    with distribution.scope():
      def dataset_fn():
        dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
        # TODO(priyag): batch with drop_remainder=True causes shapes to be
        # fully defined for TPU. Remove this when XLA supports dynamic shapes.
        return dataset.batch(batch_size=1, drop_remainder=True)

      optimizer = optimizer_fn()
      layer = core.Dense(1, use_bias=True)

      key1 = "foo"
      value1 = "bar"

      def model_fn(output_context, x):
        """A very simple model written by the user."""
        def loss_fn():
          y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
          return y * y

        if isinstance(optimizer, optimizer_v2.OptimizerV2):
          train_op = optimizer.minimize(
              loss_fn, lambda: layer.trainable_variables)
        else:
          train_op = optimizer.minimize(loss_fn)
        loss = loss_fn()
        output_context.set_last_step_output(
            name="replica_loss_reduced",
            output=loss,
            reduce_op=reduce_util.ReduceOp.MEAN)
        output_context.set_non_tensor_output(key1, value1)
        return (train_op, loss)

      def step_fn(output_context, inputs):
        (train_op, loss) = distribution.extended.call_for_each_replica(
            model_fn, args=(output_context, inputs))
        output_context.set_last_step_output(
            name="cross_replica_loss_reduced",
            output=loss,
            reduce_op=reduce_util.ReduceOp.MEAN)
        output_context.set_last_step_output(
            name="cross_replica_loss_not_reduced",
            output=loss)
        return distribution.group(train_op)

      iterator = self._get_iterator(distribution, dataset_fn)

      def run_step():
        initial_loss = lambda: constant_op.constant(1e7)
        # Initial values corresponding to reduced losses are just single
        # tensors. But for non-reduced losses, we need to have initial
        # values that are of the same structure as non-reduced losses. In
        # MirroredStrategy, this will be a list of losses; in TPUStrategy
        # it will be a single tensor. Using `call_for_each_replica` followed
        # by `experimental_local_results` gives us the desired initial
        # value structure.
        not_reduced = distribution.experimental_local_results(
            distribution.extended.call_for_each_replica(initial_loss))
        initial_loop_values = {
            "replica_loss_reduced": initial_loss(),
            "cross_replica_loss_reduced": initial_loss(),
            "cross_replica_loss_not_reduced": not_reduced,
        }
        ctx = distribution.extended.experimental_run_steps_on_iterator(
            step_fn, iterator, iterations=2,
            initial_loop_values=initial_loop_values)

        self.assertEqual({key1: (value1,)}, ctx.non_tensor_outputs)
        self._verify_loss_output(
            initial_loss(),
            loss_output=ctx.last_step_outputs["replica_loss_reduced"],
            reduced=True, distribution=distribution)
        self._verify_loss_output(
            initial_loss(),
            loss_output=ctx.last_step_outputs["cross_replica_loss_reduced"],
            reduced=True, distribution=distribution)
        self._verify_loss_output(
            initial_loss(),
            loss_output=ctx.last_step_outputs["cross_replica_loss_not_reduced"],
            reduced=False, distribution=distribution)
        return (ctx.run_op, ctx.last_step_outputs["replica_loss_reduced"])

      if not context.executing_eagerly():
        with self.cached_session() as sess:
          run_step = sess.make_callable(run_step())
      self.evaluate(variables_lib.global_variables_initializer())

      weights, biases, losses = [], [], []
      for _ in range(5):
        _, loss = run_step()
        losses.append(loss)
        weights.append(self.evaluate(layer.kernel))
        biases.append(self.evaluate(layer.bias))

      loss_is_not_increasing = all(y <= x for x, y in zip(losses, losses[1:]))
      self.assertTrue(loss_is_not_increasing)

      error = abs(
          numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
      error_is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
      self.assertTrue(error_is_not_increasing)

  def _verify_loss_output(self, initial_loss, loss_output, reduced,
                          distribution):
    if not reduced:
      self.assertLen(distribution.experimental_local_results(loss_output),
                     distribution.num_replicas_in_sync)
      loss_tensor = distribution.reduce(reduce_util.ReduceOp.MEAN, loss_output,
                                        axis=None)
    else:
      unwrapped_output = distribution.experimental_local_results(loss_output)
      self.assertLen(unwrapped_output, 1)
      loss_tensor = unwrapped_output[0]
    self.assertEqual(initial_loss.dtype, loss_tensor.dtype)
    self.assertEqual(initial_loss.shape, loss_tensor.shape)

  @ds_combinations.generate(
      optimizer_combinations.distributions_and_v2_optimizers())
  def test_empty_var_list(self, distribution, optimizer_fn):
    opt = optimizer_fn()
    with distribution.scope():

      def run_fn():
        opt.minimize(lambda: constant_op.constant(1.), [])
        opt.apply_gradients([])

      distribution.run(run_fn)


if __name__ == "__main__":
  test.main()