# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tf.distribute.Strategy for running on a single device."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TODO(josh11b): Do we wrap values in types to generate errors if you are
# doing something that won't work with other DistributionStrategy
# implementations?


@tf_export("distribute.OneDeviceStrategy", v1=[])
class OneDeviceStrategy(distribute_lib.Strategy):
  """A distribution strategy for running on a single device.

  Using this strategy will place any variables created in its scope on the
  specified device. Input distributed through this strategy will be
  prefetched to the specified device. Moreover, any functions called via
  `strategy.run` will also be placed on the specified device.

  Typical usage of this strategy could be testing your code with the
  `tf.distribute.Strategy` API before switching to other strategies which
  actually distribute to multiple devices/machines.

  For example:
  ```
  strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

  with strategy.scope():
    v = tf.Variable(1.0)
    print(v.device)  # /job:localhost/replica:0/task:0/device:GPU:0

  def step_fn(x):
    return x * 2

  result = 0
  for i in range(10):
    result += strategy.run(step_fn, args=(i,))
  print(result)  # 90
  ```
  """

  def __init__(self, device):
    """Creates a `OneDeviceStrategy`.

    Args:
      device: Device string identifier for the device on which the variables
        should be placed. See class docs for more details on how the device is
        used. Examples: "/cpu:0", "/gpu:0", "/device:CPU:0", "/device:GPU:0"
    """
    super(OneDeviceStrategy, self).__init__(OneDeviceExtended(self, device))
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "OneDeviceStrategy")

  def experimental_distribute_dataset(self, dataset, options=None):  # pylint: disable=useless-super-delegation
    """Distributes a tf.data.Dataset instance provided via dataset.

    In this case, there is only one device, so this is only a thin wrapper
    around the input dataset. It will, however, prefetch the input data to the
    specified device.
    The returned distributed dataset can be iterated over just like regular
    datasets.

    NOTE: Currently, the user cannot add any more transformations to a
    distributed dataset.

    Example:
    ```
    strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    dataset = tf.data.Dataset.range(10).batch(2)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    for x in dist_dataset:
      print(x)  # [0, 1], [2, 3],...
    ```

    Args:
      dataset: `tf.data.Dataset` to be prefetched to device.
      options: `tf.distribute.InputOptions` used to control how this dataset
        is distributed.

    Returns:
      A "distributed `Dataset`" that the caller can iterate over.
    """
    return super(OneDeviceStrategy, self).experimental_distribute_dataset(
        dataset, options)

  def distribute_datasets_from_function(
      self,
      dataset_fn,  # pylint: disable=useless-super-delegation
      options=None):
    """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.

    `dataset_fn` will be called once for each worker in the strategy. In this
    case, we only have one worker and one device, so `dataset_fn` is called
    once.

    The `dataset_fn` should take a `tf.distribute.InputContext` instance where
    information about batching and input replication can be accessed:

    ```
    def dataset_fn(input_context):
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
      return d.shard(
          input_context.num_input_pipelines, input_context.input_pipeline_id)

    inputs = strategy.distribute_datasets_from_function(dataset_fn)

    for batch in inputs:
      replica_results = strategy.run(replica_fn, args=(batch,))
    ```

    IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
    per-replica batch size, unlike `experimental_distribute_dataset`, which
    uses the global batch size. This may be computed using
    `input_context.get_per_replica_batch_size`.

    Args:
      dataset_fn: A function taking a `tf.distribute.InputContext` instance and
        returning a `tf.data.Dataset`.
      options: `tf.distribute.InputOptions` used to control how this dataset
        is distributed.

    Returns:
      A "distributed `Dataset`", which the caller can iterate over like regular
      datasets.
    """
    return super(OneDeviceStrategy,
                 self).distribute_datasets_from_function(dataset_fn, options)

  def experimental_local_results(self, value):  # pylint: disable=useless-super-delegation
    """Returns the list of all local per-replica values contained in `value`.

    In `OneDeviceStrategy`, the `value` is always expected to be a single
    value, so the result is just the value in a tuple.

    Args:
      value: A value returned by `experimental_run()`, `run()`,
        `extended.call_for_each_replica()`, or a variable created in `scope`.

    Returns:
      A tuple of values contained in `value`. If `value` represents a single
      value, this returns `(value,)`.
    """
    return super(OneDeviceStrategy, self).experimental_local_results(value)

  def run(self, fn, args=(), kwargs=None, options=None):  # pylint: disable=useless-super-delegation
    """Run `fn` on each replica, with the given arguments.

    In `OneDeviceStrategy`, `fn` is simply called within a device scope for the
    given device, with the provided arguments.

    Args:
      fn: The function to run.
        The output must be a `tf.nest` of `Tensor`s.
      args: (Optional) Positional arguments to `fn`.
      kwargs: (Optional) Keyword arguments to `fn`.
      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
        the options to run `fn`.

    Returns:
      Return value from running `fn`.
    """
    return super(OneDeviceStrategy, self).run(fn, args, kwargs, options)

  def reduce(self, reduce_op, value, axis):  # pylint: disable=useless-super-delegation
    """Reduce `value` across replicas.

    In `OneDeviceStrategy`, there is only one replica, so if `axis=None`, the
    value is simply returned. If `axis` is specified as something other than
    `None`, such as `axis=0`, the value is reduced along that axis and
    returned.

    Example:
    ```
    t = tf.range(10)

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, t, axis=None).numpy()
    # result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, t, axis=0).numpy()
    # result: 45
    ```

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined.
      value: A "per replica" value, e.g. returned by `run`, to be combined
        into a single tensor.
      axis: Specifies the dimension to reduce along within each
        replica's tensor. Should typically be set to the batch dimension, or
        `None` to only reduce across replicas (e.g. if the tensor has no batch
        dimension).

    Returns:
      A `Tensor`.
    """
    return super(OneDeviceStrategy, self).reduce(reduce_op, value, axis)

  def scope(self):  # pylint: disable=useless-super-delegation
    """Returns a context manager selecting this Strategy as current.

    Inside a `with strategy.scope():` code block, this thread
    will use a variable creator set by `strategy`, and will
    enter its "cross-replica context".

    In `OneDeviceStrategy`, all variables created inside `strategy.scope()`
    will be on the `device` specified at strategy construction time.
    See the example in the docs for this class.

    Returns:
      A context manager to use for creating variables with this strategy.
    """
    return super(OneDeviceStrategy, self).scope()


@tf_export(v1=["distribute.OneDeviceStrategy"])  # pylint: disable=empty-docstring
class OneDeviceStrategyV1(distribute_lib.StrategyV1):

  __doc__ = OneDeviceStrategy.__doc__.replace(
      "For example:\n  ```",
      "For example:\n  ```\n  tf.enable_eager_execution()")

  def __init__(self, device):
    super(OneDeviceStrategyV1, self).__init__(OneDeviceExtended(self, device))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "OneDeviceStrategy")
  __init__.__doc__ = OneDeviceStrategy.__init__.__doc__
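

# Illustrative sketch (added for documentation only; not part of the public
# TensorFlow API and never called by this module). It shows how `run`,
# `experimental_local_results` and `reduce` compose under `OneDeviceStrategy`,
# assuming TF 2.x eager execution. The helper name `_example_run_and_reduce`
# and the concrete device string are hypothetical.
def _example_run_and_reduce():
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
  dataset = tf.data.Dataset.range(8).batch(4)
  dist_dataset = strategy.experimental_distribute_dataset(dataset)

  def step_fn(batch):
    # Runs inside a device scope for "/cpu:0"; there is a single replica.
    return batch * 2

  sums = []
  for batch in dist_dataset:
    per_replica = strategy.run(step_fn, args=(batch,))
    # With one device, the local results are just a one-element tuple
    # containing the same tensor that `run` returned.
    (local_value,) = strategy.experimental_local_results(per_replica)
    # Reduce over the batch dimension of the single replica's tensor.
    sums.append(
        strategy.reduce(tf.distribute.ReduceOp.SUM, local_value, axis=0))
  return sums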


# TODO(josh11b): Switch to V2 after callers have been updated to only V2 APIs.
class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of OneDeviceStrategy."""

  def __init__(self, container_strategy, device):
    super(OneDeviceExtended, self).__init__(container_strategy)
    self._device = device_util.resolve(device)
    self._input_device = device_util.get_host_for_device(self._device)

  def _input_workers_with_options(self, options=None):
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers([(self._input_device, (self._device,))])
    else:
      return input_lib.InputWorkers([(self._input_device,
                                      (self._input_device,))])

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _create_variable(self, next_creator, **kwargs):
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      with ops.device(self._device):
        return next_creator(**kwargs)
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      with ops.colocate_with(colocate_with):
        return next_creator(**kwargs)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    """Make iterator from dataset without splitting the batch."""
    # Note that split_batch_by argument is not passed because it is always 1 in
    # this strategy, and adding it adds unnecessary overhead to the dataset.
    return input_lib.DatasetIterator(dataset, self._input_workers,
                                     self._container_strategy())

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           [distribute_lib.InputContext()],
                                           self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, numpy_dataset.SingleDevice(self._input_device), session)

  def _broadcast_to(self, tensor, destinations):
    del destinations
    return tensor

  def _experimental_distribute_dataset(self, dataset, options):
    # Note that split_batch_by argument is not passed because it is always 1 in
    # this strategy, and adding it adds unnecessary overhead to the dataset.
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        options=options)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options),
        [distribute_lib.InputContext()],
        self._container_strategy(),
        options=options)

  def _experimental_distribute_values_from_function(self, value_fn):
    # TODO(b/137795644): This should return a PerReplica value but other
    # methods like run in OneDeviceStrategy need to be modified
    # to do the same.
    return value_fn(distribute_lib.ValueContext())

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things,
    # e.g. create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    # TODO(priyag): Use max_iterations instead of an explicit counter.
    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx

  def _call_for_each_replica(self, fn, args, kwargs):
    strategy = self._container_strategy()
    with ops.device(self._device), _OneDeviceReplicaContext(strategy):
      return fn(*args, **kwargs)

  def _reduce_to(self, reduce_op, value, destinations, options):
    del reduce_op, destinations, options
    return value

  def _gather_to_implementation(self, value, destinations, axis, options):
    del destinations, axis, options
    return value

  def _update(self, var, fn, args, kwargs, group):
    # The implementations of _update() and _update_non_slot() are identical
    # except _update() passes `var` as the first argument to `fn()`.
    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    del colocate_with
    with ops.device(self._device), distribute_lib.UpdateContext(self._device):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
    return array_ops.identity(replica_local_var)

  def _local_results(self, value):
    return (value,)

  def value_container(self, value):
    return value

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return False

  @property
  def _num_replicas_in_sync(self):
    return 1

  @property
  def worker_devices(self):
    return (self._device,)

  @property
  def parameter_devices(self):
    return (self._device,)

  def non_slot_devices(self, var_list):
    del var_list
    return (self._device,)

  @property
  def experimental_should_init(self):
    return True

  @property
  def experimental_between_graph(self):
    return False

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """Global and per-replica batching are equivalent for OneDeviceStrategy."""
    return True

  @property
  def _support_per_replica_values(self):
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group


class _OneDeviceReplicaContext(distribute_lib.ReplicaContext):
  """ReplicaContext for OneDeviceStrategy."""

  def __init__(self, strategy):
    distribute_lib.ReplicaContext.__init__(
        self, strategy, replica_id_in_sync_group=0)

  @property
  def devices(self):
    return self._strategy.extended.worker_devices
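

# Illustrative sketch (added for documentation only; not part of the public
# TensorFlow API and never called by this module). It shows variable placement
# under `scope()` and per-replica batching with
# `distribute_datasets_from_function`, assuming TF 2.x eager execution. The
# helper name `_example_datasets_from_function`, the device string and the
# batch sizes are hypothetical.
def _example_datasets_from_function():
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
  global_batch_size = 4

  def dataset_fn(input_context):
    # With one worker and one replica, `dataset_fn` is called exactly once and
    # the per-replica batch size equals the global batch size.
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    return tf.data.Dataset.range(16).batch(batch_size).shard(
        input_context.num_input_pipelines, input_context.input_pipeline_id)

  inputs = strategy.distribute_datasets_from_function(dataset_fn)

  with strategy.scope():
    # The variable lands on the device given at construction time.
    weight = tf.Variable(1.0)

  def replica_fn(batch):
    return tf.cast(batch, tf.float32) * weight

  # Each call to `run` executes `replica_fn` in a device scope for "/cpu:0".
  return [strategy.run(replica_fn, args=(batch,)) for batch in inputs]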