# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy

from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.tpu import training_loop
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


def get_tpu_system_metadata(tpu_cluster_resolver):
  """Retrieves TPU system metadata given a TPUClusterResolver."""
  master = tpu_cluster_resolver.master()

  # pylint: disable=protected-access
  cluster_spec = tpu_cluster_resolver.cluster_spec()
  cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
  tpu_system_metadata = (
      tpu_system_metadata_lib._query_tpu_system_metadata(
          master,
          cluster_def=cluster_def,
          query_topology=False))

  return tpu_system_metadata


# TODO(jhseu): Deduplicate with MirroredStrategy?
def _create_tpu_mirrored_variable(  # pylint: disable=missing-docstring
    strategy, device_map, logical_device, real_mirrored_creator,
    *args, **kwargs):
  # Figure out what collections this variable should be added to.
  # We'll add the TPUMirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
  if var_collections is None:
    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []

  # TODO(jhseu): Should we have different behavior for different
  # synchronization settings?

  # Get aggregation value
  # TODO(jhseu): Support aggregation in a replica context.
  aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
  if aggregation not in [
      vs.VariableAggregation.NONE,
      vs.VariableAggregation.SUM,
      vs.VariableAggregation.MEAN,
      vs.VariableAggregation.ONLY_FIRST_REPLICA,
  ]:
    raise ValueError("Invalid variable aggregation mode: {} for variable: {}"
                     .format(aggregation, kwargs["name"]))

  # Ignore user-specified caching device, not needed for mirrored variables.
  kwargs.pop("caching_device", None)

  # TODO(josh11b,apassos): It would be better if variable initialization
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
    devices = device_map.logical_to_actual_devices(logical_device)
    value_list = real_mirrored_creator(devices, *args, **kwargs)
    result = values.TPUMirroredVariable(
        strategy, device_map, value_list, aggregation,
        logical_device=logical_device)

  if not (context.executing_eagerly() or ops.inside_function()):
    g = ops.get_default_graph()
    # If "trainable" is True, next_creator() will add the member variables
    # to the TRAINABLE_VARIABLES collection, so we manually remove
    # them and replace with the MirroredVariable. We can't set
    # "trainable" to False for next_creator() since that causes functions
    # like implicit_gradients to skip those variables.
    if kwargs.get("trainable", True):
      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
      l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
      for v in value_list:
        l.remove(v)
    g.add_to_collections(var_collections, result)
  return result


@tf_export("distribute.experimental.TPUStrategy")
class TPUStrategy(distribute_lib.DistributionStrategy):
  """TPU distribution strategy implementation."""

  def __init__(self,
               tpu_cluster_resolver=None,
               steps_per_run=None,
               device_assignment=None):
    """Initializes the TPUStrategy object.

    Args:
      tpu_cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
      steps_per_run: Number of steps to run on device before returning to the
        host. Note that this can have side effects on performance, hooks,
        metrics, summaries, etc. This parameter is only used when Distribution
        Strategy is used with estimator or keras.
      device_assignment: Optional `tf.contrib.tpu.DeviceAssignment` to specify
        the placement of replicas on the TPU cluster. Currently only supports
        the use case of using a single core within a TPU cluster.
    """
    super(TPUStrategy, self).__init__(TPUExtended(
        self, tpu_cluster_resolver, steps_per_run, device_assignment))

  @property
  def steps_per_run(self):
    """DEPRECATED: use .extended.steps_per_run instead."""
    return self._extended.steps_per_run

  # TODO(cjfj): Modify `_call_for_each_replica` in `TPUExtended` such that this
  # can use the default implementation.
  # This implementation runs a single step. It does not use infeed or outfeed.
  def experimental_run_v2(self, fn, args=(), kwargs=None):
    """See base class."""
    if context.executing_eagerly() and not ops.inside_function():
      raise NotImplementedError(
          "Eager mode not supported in TPUStrategy outside TF functions.")

    if kwargs is None:
      kwargs = {}

    result = [None]
    def replicated_fn(replica_id, replica_args, replica_kwargs):
      """Wraps user function to provide replica ID and `Tensor` inputs."""
      with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
        result[0] = fn(*replica_args, **replica_kwargs)
      return result[0]

    replicate_inputs = []  # By replica.
    for i in range(self.num_replicas_in_sync):
      replicate_inputs.append(
          [constant_op.constant(i, dtype=dtypes.int32),
           values.select_replica(i, args),
           values.select_replica(i, kwargs)])

    with self.scope():
      replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

    # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
    replicate_outputs = [
        nest.pack_sequence_as(result[0], nest.flatten(replica_outputs))
        for replica_outputs in replicate_outputs]

    device_map = self.extended._device_map  # pylint: disable=protected-access
    return values.regroup(device_map, replicate_outputs)


class TPUExtended(distribute_lib.DistributionStrategyExtended):
  """Implementation of TPUStrategy."""

  def __init__(self,
               container_strategy,
               tpu_cluster_resolver=None,
               steps_per_run=None,
               device_assignment=None):
    super(TPUExtended, self).__init__(container_strategy)

    if tpu_cluster_resolver is None:
      tpu_cluster_resolver = TPUClusterResolver("")

    if steps_per_run is None:
      # TODO(frankchn): Warn when we are being used by DS/Keras and this is
      # not specified.
      steps_per_run = 1

    self._tpu_cluster_resolver = tpu_cluster_resolver
    self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
    self._device_assignment = device_assignment

    # Device assignment is currently only supported for 1 core case.
    if self._device_assignment:
      assert isinstance(self._device_assignment,
                        device_assignment_lib.DeviceAssignment)
      if self._device_assignment.num_replicas != 1:
        raise ValueError("Device assignment is only supported for a single "
                         "core single replica case currently.")
      if self._device_assignment.num_cores_per_replica != 1:
        raise ValueError("Device assignment is only supported for a single "
                         "core single replica case currently.")
      if not all(self._device_assignment.core_assignment[0][0] == [0, 0, 0]):
        raise ValueError("Device assignment is only supported for a single "
                         "core single replica case currently.")

    # TODO(jhseu): Switch to DeviceAssignment to support pods and model
    # parallelism.
    self._tpu_devices = [d.name for d in self._tpu_metadata.devices
                         if "device:TPU:" in d.name]

    self._host_device = tpu_strategy_util.get_first_tpu_host_device(
        self._tpu_cluster_resolver)

    # Only create variables for the number of replicas we're running.
    self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
    self._device_map = values.ReplicaDeviceMap(self._tpu_devices)

    # Preload the data onto the TPUs.
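    # Group the selected TPU devices by the CPU host that feeds them, so that
    # each host runs one input pipeline serving only its locally attached
    # TPU cores.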
    input_worker_devices = collections.OrderedDict()
    for tpu_device in self._tpu_devices:
      host_device = _get_host_for_device(tpu_device)
      input_worker_devices.setdefault(host_device, [])
      input_worker_devices[host_device].append(tpu_device)
    self._input_workers = input_lib.InputWorkers(
        self._device_map, tuple(input_worker_devices.items()))

    # TODO(sourabhbajaj): Remove this once performance of running one step
    # at a time is comparable to multiple steps.
    self.steps_per_run = steps_per_run
    self._require_static_shapes = True

  def _validate_colocate_with_variable(self, colocate_with_variable):
    values.validate_colocate_tpu_variable(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    """Make iterators for each of the TPU hosts."""
    return input_lib.DatasetIterator(dataset, self._input_workers,
                                     self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    input_contexts = []
    num_workers = self._input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))
    return input_lib.InputFunctionIterator(
        input_fn, self._input_workers, input_contexts)

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, numpy_dataset.SingleDevice(self._host_device),
        session)

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
  # a mechanism to infer the outputs of `fn`. Pending b/110550782.
  def _experimental_run_steps_on_iterator(
      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
    output_shapes = multi_worker_iterator.output_shapes
    shapes = nest.flatten(output_shapes)
    if any(not s.is_fully_defined() for s in shapes):
      raise ValueError(
          "TPU currently requires fully defined shapes. Either use "
          "set_shape() on the input tensors or use "
          "dataset.batch(..., drop_remainder=True).")

    # Wrap `fn` for repeat.
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)
    ctx = input_lib.MultiStepContext()

    def run_fn(inputs):
      """Single step on the TPU device."""
      fn_result = fn(ctx, inputs)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      if flat_last_step_outputs:
        with ops.control_dependencies([fn_result]):
          return [array_ops.identity(f) for f in flat_last_step_outputs]
      else:
        return fn_result

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop and TPU replicate context. This is useful in cases
    # where we might need to exit these contexts and get back to the outer
    # context to do some things, e.g. creating an op which should be
    # evaluated only once at the end of the loop on the host. One such usage
    # is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    def rewrite_fn(*args):
      """The rewritten step fn running on TPU."""
      del args

      per_replica_inputs = multi_worker_iterator.get_next()
      replicate_inputs = []
      for replica_id in range(self._num_replicas_in_sync):
        select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
        replicate_inputs.append((nest.map_structure(
            select_replica, per_replica_inputs),))

      replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

      # If run_fn has tensor outputs, tpu.replicate returns a list of lists.
      # We will flatten it in this case. If run_fn has no tensor outputs,
      # tpu.replicate returns a list of no_ops, and we will keep the output
      # as it is.
      if isinstance(replicate_outputs[0], list):
        replicate_outputs = nest.flatten(replicate_outputs)

      return replicate_outputs

    # TODO(sourabhbajaj): The input to while loop should be based on the
    # output type of the step_fn
    assert isinstance(initial_loop_values, list)
    initial_loop_values = initial_loop_values * self._num_replicas_in_sync

    # Put the while loop op on TPU host 0.
    with ops.device(self._host_device):
      if self.steps_per_run == 1:
        replicate_outputs = rewrite_fn()
      else:
        replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
                                                 initial_loop_values)

    del self._outer_control_flow_context
    ctx.run_op = control_flow_ops.group(replicate_outputs)

    if isinstance(replicate_outputs, list):
      # Filter out any ops from the outputs, typically this would be the case
      # when there were no tensor outputs.
      last_step_tensor_outputs = [
          x for x in replicate_outputs if not isinstance(x, ops.Operation)
      ]

      # Outputs are currently of the structure (flattened)
      # [output0_device0, output1_device0, output2_device0,
      #  output0_device1, output1_device1, output2_device1,
      #  ...]
      # Convert this to the following structure instead: (grouped by output)
      # [[output0_device0, output0_device1],
      #  [output1_device0, output1_device1],
      #  [output2_device0, output2_device1]]
      output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
      last_step_tensor_outputs = [
          last_step_tensor_outputs[i::output_num] for i in range(output_num)
      ]
    else:
      # No tensors returned.
      last_step_tensor_outputs = []

    _set_last_step_outputs(ctx, last_step_tensor_outputs)
    return ctx

  def _call_for_each_replica(self, fn, args, kwargs):
    # TODO(jhseu): Consider making it so call_for_each_replica implies that
    # we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly.
    with _TPUReplicaContext(self._container_strategy()):
      return fn(*args, **kwargs)

  def _experimental_initialize_system(self):
    """Experimental method added to be used by Estimator.

    This is a private method only to be used by Estimator. Other frameworks
    should directly be calling `tf.contrib.distribute.initialize_tpu_system`.
    """
    tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)

  def _create_variable(self, next_creator, *args, **kwargs):
    """Create a TPUMirroredVariable.
    See `DistributionStrategy.scope`."""
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      device_map = self._device_map
      logical_device = 0  # TODO(josh11b): Get logical device from scope here.
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(*args, **kwargs)
    else:
      device_map = colocate_with.device_map
      logical_device = colocate_with.logical_device

    def _real_mirrored_creator(devices, *args, **kwargs):  # pylint: disable=g-missing-docstring
      value_list = []
      for i, d in enumerate(devices):
        with ops.device(d):
          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0
            # to ensure that we ignore the name scope and instead use the
            # given name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
            # Initialize replicas with the same value:
            if context.executing_eagerly() or ops.inside_function():
              with ops.init_scope():
                kwargs["initial_value"] = array_ops.identity(
                    value_list[0].value())
            else:
              def initial_value_fn(device=d):
                with ops.device(device):
                  return array_ops.identity(value_list[0].initial_value)
              kwargs["initial_value"] = initial_value_fn
          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            v = next_creator(*args, **kwargs)
          assert not isinstance(v, values.TPUMirroredVariable)
          value_list.append(v)
      return value_list

    return _create_tpu_mirrored_variable(
        self._container_strategy(), device_map, logical_device,
        _real_mirrored_creator, *args, **kwargs)

  def _reduce_to(self, reduce_op, value, destinations):
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # TODO(jhseu): Revisit once we support model-parallelism.
        value *= (1. / self._num_replicas_in_sync)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise NotImplementedError(
            "Currently only support sum & mean in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)

    devices = cross_device_ops_lib.get_devices_from(destinations)
    if len(devices) != 1:
      raise ValueError("Multiple devices are not supported for TPUStrategy")

    # Always performs the reduction on the TPU host.
    with ops.device(self._host_device):
      output = math_ops.add_n(value.values)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        output *= (1. / len(value.values))

    # If necessary, copy to requested destination.
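    # Canonicalize both device strings so that different spellings of the same
    # device compare equal before deciding whether a copy is needed.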
    dest_canonical = device_util.canonicalize(devices[0])
    host_canonical = device_util.canonicalize(self._host_device)

    if dest_canonical != host_canonical:
      with ops.device(devices[0]):
        output = array_ops.identity(output)

    return output

  def _update(self, var, fn, args, kwargs, group):
    assert isinstance(var, values.TPUMirroredVariable)
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if group:
        return fn(var, *args, **kwargs)
      else:
        return (fn(var, *args, **kwargs),)

    # Otherwise, we revert to MirroredStrategy behavior and update each
    # variable directly.
    updates = []
    for i, (d, v) in enumerate(zip(var.devices, var.values)):
      name = "update_%d" % i
      with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(fn(v,
                          *values.select_device_mirrored(d, args),
                          **values.select_device_mirrored(d, kwargs)))
    return values.update_regroup(self, self._device_map, updates, group)

  def read_var(self, var):
    assert isinstance(var, values.TPUMirroredVariable)
    return var.read_value()

  def _local_results(self, val):
    if isinstance(val, values.DistributedValues):
      # Return in a deterministic order.
      return tuple(val.get(device=d) for d in sorted(val.devices))
    elif isinstance(val, list):
      # TODO(josh11b): We need to remove this case; per device values should
      # be represented using a PerReplica wrapper instead of a list with
      # one entry per device.
      return tuple(val)
    elif isinstance(val, values.TPUMirroredVariable):
      # pylint: disable=protected-access
      if values._enclosing_tpu_context() is not None:
        return (val,)
      return val.values
    return (val,)

  def value_container(self, value):
    return value

  def _broadcast_to(self, tensor, destinations):
    del destinations
    return tensor

  @property
  def num_hosts(self):
    if self._device_assignment is None:
      return self._tpu_metadata.num_hosts

    return len(set([self._device_assignment.host_device(r)
                    for r in range(self._device_assignment.num_replicas)]))

  @property
  def num_replicas_per_host(self):
    if self._device_assignment is None:
      return self._tpu_metadata.num_of_cores_per_host

    # TODO(sourabhbajaj): Remove this method once we use inputs and remove
    # infeed, as the computation of num_replicas_per_host is not a constant
    # when using device_assignment. This is a temporary workaround to support
    # StatefulRNN as everything is 1 in that case.
    # This method needs to take host_id as input for correct computation.
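    # A host can hold at most (cores per host // cores per replica) model
    # replicas; the number of replicas actually placed per host is capped by
    # the total number of replicas in the device assignment.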
    max_models_per_host = (self._tpu_metadata.num_of_cores_per_host //
                           self._device_assignment.num_cores_per_replica)
    models_per_host = min(self._device_assignment.num_replicas,
                          max_models_per_host)
    return models_per_host * self._device_assignment.num_cores_per_replica

  @property
  def _num_replicas_in_sync(self):
    if self._device_assignment is None:
      return self._tpu_metadata.num_cores
    return (self._device_assignment.num_replicas *
            self._device_assignment.num_cores_per_replica)

  @property
  def experimental_between_graph(self):
    return False

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  @property
  def worker_devices(self):
    return self._tpu_devices

  @property
  def parameter_devices(self):
    return self._tpu_devices

  def non_slot_devices(self, var_list):
    return self._host_device

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    del colocate_with
    with ops.device(self._host_device), distribute_lib.UpdateContext(
        self._host_device):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    del cluster_spec, task_type, task_id
    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    updated_config.isolate_session_state = True
    cluster_spec = self._tpu_cluster_resolver.cluster_spec()
    if cluster_spec:
      updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
    return updated_config

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True


class _TPUReplicaContext(distribute_lib.ReplicaContext):
  """Replication Context class for TPU Strategy."""

  # TODO(sourabhbajaj): Call for each replica should be updating this.
  # TODO(b/118385803): Always properly initialize replica_id.
  def __init__(self, strategy, replica_id_in_sync_group=None):
    if replica_id_in_sync_group is None:
      replica_id_in_sync_group = constant_op.constant(0, dtypes.int32)
    distribute_lib.ReplicaContext.__init__(
        self, strategy, replica_id_in_sync_group=replica_id_in_sync_group)

  @property
  def devices(self):
    distribute_lib.require_replica_context(self)
    ds = self._strategy
    replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)

    if replica_id is None:  # Non-constant `Tensor` inside `tpu.replicate`.
      # TODO(cjfj): Return other devices when model parallelism is supported.
      return (tpu.core(0),)
    else:
      return (ds.extended.worker_devices[replica_id],)


def _get_host_for_device(device):
  spec = tf_device.DeviceSpec.from_string(device)
  return tf_device.DeviceSpec(
      job=spec.job, replica=spec.replica, task=spec.task,
      device_type="CPU", device_index=0).to_string()


def _set_last_step_outputs(ctx, last_step_tensor_outputs):
  """Sets the last step outputs on the given context."""
  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that have already been reduced, take the first value
    # from the list as each value should be the same. Else return the full
    # list of values.
    # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
    # value.
    if reduce_op is not None:
      # TODO(priyag): Should this return the element or a list with 1 element
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
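

# A minimal usage sketch (illustrative only, not part of this module). It
# assumes a reachable TPU worker; the endpoint string below is a placeholder:
#
#   resolver = TPUClusterResolver(tpu="grpc://<tpu-worker>:8470")
#   tpu_strategy_util.initialize_tpu_system(resolver)
#   strategy = TPUStrategy(tpu_cluster_resolver=resolver)
#   with strategy.scope():
#     ...  # Variables created here become TPUMirroredVariables.
#
# Step functions are then dispatched with
# `strategy.experimental_run_v2(step_fn, args=(...,))` from inside a TF
# function; eager calls outside functions raise NotImplementedError above.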