# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed values."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import weakref
import six

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import saver
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest


def _devices_match(d1, d2):
  return device_util.canonicalize(d1) == device_util.canonicalize(d2)


class DeviceMap(object):
  """A mapping of replicas & logical device ids to devices."""

  @property
  def all_devices(self):
    """Returns a tuple of strings with all devices in this DeviceMap."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  @property
  def devices_by_replica(self):
    """Returns a tuple `t` where `t[replica]` is the devices for `replica`."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  @property
  def num_logical_devices(self):
    """Count of the number of devices each replica may be defined across."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  @property
  def num_replicas_in_graph(self):
    """Number of replicas defined in this graph."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  def logical_device_from_values(self, values):
    """Returns the logical device index `values` is on."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  def logical_to_actual_devices(self, logical_device_id):
    """Returns sequence of `num_replicas_in_graph` devices."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  def select_for_current_replica(self, values, replica_context):
    """Select the element of `values` for the current replica."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  def replica_for_device(self, device):
    """Return the replica id containing `device`."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  def select_for_device(self, values, device):
    """Select the element of `values` to access from `device`."""
    raise NotImplementedError("Required for DeviceMap implementations.")

  def is_device_in_replica(self, device, replica_id):
    """Returns whether `device` is a member of replica `replica_id`."""
    raise NotImplementedError("Required for DeviceMap implementations.")


class SingleDeviceMap(DeviceMap):
  """A device map for 1 non-computation device.

  Use `SingleDeviceMap` when the device does not correspond to some replica of
  the computation. For computation devices, use `ReplicaDeviceMap` below (even
  if there is only a single device in the map).
  """

  def __init__(self, device):
    """Initialize a `SingleDeviceMap`.

    Args:
      device: A string device.
    """
    assert isinstance(device, six.string_types)
    self._device = device_util.canonicalize(device)
    self._devices = (self._device,)

  @property
  def all_devices(self):
    return self._devices

  @property
  def devices_by_replica(self):
    raise ValueError("SingleDeviceMap not indexed by replicas")

  @property
  def num_logical_devices(self):
    return 1

  @property
  def num_replicas_in_graph(self):
    return 1

  def logical_device_from_values(self, values):
    del values
    return 0

  def logical_to_actual_devices(self, logical_device_id):
    assert logical_device_id == 0
    return self._devices

  def select_for_current_replica(self, values, replica_context):
    assert len(values) == 1
    del replica_context
    return values[0]

  def replica_for_device(self, device):
    raise ValueError("SingleDeviceMap not indexed by replicas")

  def select_for_device(self, values, device):
    assert len(values) == 1
    if self._device != device:
      raise ValueError("Device %s not found in %s (current device %s)" %
                       (device, self._devices, device_util.current()))
    return values[0]

  def is_device_in_replica(self, device, replica_id):
    raise ValueError("SingleDeviceMap not indexed by replicas")

  def __repr__(self):
    return "%s(%r)" % (self.__class__.__name__, self._device)


class ReplicaDeviceMap(DeviceMap):
  """A device map for 1 device per replica."""

  def __init__(self, devices):
    """Initialize a `ReplicaDeviceMap`.

    Args:
      devices: `devices[i]` is the string device for replica `i`.
168 """ 169 self._devices = tuple(device_util.canonicalize(d) for d in devices) 170 if len(set(self._devices)) != len(self._devices): 171 raise ValueError("Duplicate devices in %s, after canonicalization: %s" % 172 (devices, self._devices)) 173 self._device_to_replica = {d: r for r, d in enumerate(self._devices)} 174 175 @property 176 def all_devices(self): 177 return self._devices 178 179 @property 180 def devices_by_replica(self): 181 return ((d,) for d in self._devices) 182 183 @property 184 def num_logical_devices(self): 185 return 1 186 187 @property 188 def num_replicas_in_graph(self): 189 return len(self._devices) 190 191 def logical_device_from_values(self, values): 192 del values 193 return 0 194 195 def logical_to_actual_devices(self, logical_device_id): 196 assert logical_device_id == 0 197 return self._devices 198 199 def select_for_current_replica(self, values, replica_context): 200 assert len(values) == len(self._devices) 201 replica_id = replica_context.replica_id_in_sync_group 202 if not isinstance(replica_id, int): 203 replica_id = tensor_util.constant_value(replica_id) 204 return values[replica_id] 205 206 def replica_for_device(self, device): 207 return self._device_to_replica.get(device) 208 209 def select_for_device(self, values, device): 210 assert len(values) == len(self._devices) 211 replica_id = self._device_to_replica.get(device) 212 if replica_id is None: 213 raise ValueError("Device %s not found in %s (current device %s)" % 214 (device, self._devices, device_util.current())) 215 return values[replica_id] 216 217 def is_device_in_replica(self, device, replica_id): 218 return _devices_match(device, self._devices[replica_id]) 219 220 def __str__(self): 221 return "[%s]" % (", ".join(self._devices)) 222 223 def __repr__(self): 224 return "%s([%s])" % (self.__class__.__name__, 225 ", ".join(repr(d) for d in self._devices)) 226 227 228LogicalDeviceSpec = collections.namedtuple( 229 "LogicalDeviceSpec", ("device_map", "logical_device")) 230 231 232class DistributedValues(object): 233 """Holds a map from device to values. Either PerReplica or Mirrored.""" 234 235 def __init__(self, device_map, values, logical_device=None): 236 assert isinstance(device_map, DeviceMap) 237 self._device_map = device_map 238 self._values = tuple(values) 239 if logical_device is None: 240 logical_device = device_map.logical_device_from_values(self._values) 241 self._logical_device = logical_device 242 243 # TODO(josh11b): Split this into two functions, one with device, one without. 244 def get(self, device=None): 245 """Returns the value for the current device or raises a ValueError.""" 246 if device is None: 247 replica_context = distribution_strategy_context.get_replica_context() 248 if replica_context: 249 return self._device_map.select_for_current_replica( 250 self._values, replica_context) 251 else: 252 device = distribute_lib.get_update_device() 253 if device is None: 254 return self._get_cross_replica() 255 device = device_util.canonicalize(device) 256 return self._device_map.select_for_device(self._values, device) 257 258 @property 259 def primary(self): 260 """Returns a representative component.""" 261 return self._values[0] 262 263 @property 264 def devices(self): 265 return self._device_map.logical_to_actual_devices(self._logical_device) 266 267 @property 268 def logical_device(self): 269 return self._logical_device 270 271 @property 272 def device_map(self): 273 return self._device_map 274 275 # TODO(josh11b): Replace experimental_local_results with this? 


LogicalDeviceSpec = collections.namedtuple(
    "LogicalDeviceSpec", ("device_map", "logical_device"))


class DistributedValues(object):
  """Holds a map from device to values. Either PerReplica or Mirrored."""

  def __init__(self, device_map, values, logical_device=None):
    assert isinstance(device_map, DeviceMap)
    self._device_map = device_map
    self._values = tuple(values)
    if logical_device is None:
      logical_device = device_map.logical_device_from_values(self._values)
    self._logical_device = logical_device

  # TODO(josh11b): Split this into two functions, one with device, one without.
  def get(self, device=None):
    """Returns the value for the current device or raises a ValueError."""
    if device is None:
      replica_context = distribution_strategy_context.get_replica_context()
      if replica_context:
        return self._device_map.select_for_current_replica(
            self._values, replica_context)
      else:
        device = distribute_lib.get_update_device()
        if device is None:
          return self._get_cross_replica()
    device = device_util.canonicalize(device)
    return self._device_map.select_for_device(self._values, device)

  @property
  def primary(self):
    """Returns a representative component."""
    return self._values[0]

  @property
  def devices(self):
    return self._device_map.logical_to_actual_devices(self._logical_device)

  @property
  def logical_device(self):
    return self._logical_device

  @property
  def device_map(self):
    return self._device_map

  # TODO(josh11b): Replace experimental_local_results with this?
  @property
  def values(self):
    return self._values

  @property
  def is_tensor_like(self):
    for v in self._values:
      if not tensor_util.is_tensor(v):
        return False
    return True

  def __str__(self):
    devices = self.devices
    assert len(self._values) == len(devices)
    debug_str = ",\n".join("  %d %s: %s" % (i, devices[i], self._values[i])
                           for i in range(len(devices)))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)

  def __repr__(self):
    devices = self.devices
    assert len(self._values) == len(devices)
    debug_repr = ",\n".join("  %d %s: %r" % (i, devices[i], self._values[i])
                            for i in range(len(devices)))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)


# NOTE(josh11b,apassos): It would be great if we could inspect the values this was
# initialized with and use that to generate the overloaded operators here.
# Unfortunately, Python's rules for special methods don't allow this, see
# https://docs.python.org/3/reference/datamodel.html#special-method-names
# "if a class defines a method named __getitem__(), and x is an instance of
# this class, then x[i] is roughly equivalent to type(x).__getitem__(x, i)."
# In particular, these special methods don't go through __getattr__, and
# it will only use those methods if they are defined in the class, not the
# object.
class DistributedDelegate(DistributedValues):
  """A map from device to values; acts as the same type as the values."""

  def __getattr__(self, name):
    # TODO(priyag): This needs to be made robust against pitfalls from mixed
    # use of __getattr__ and @property. See b/120402273.
    return getattr(self.get(), name)

  # pylint: disable=multiple-statements
  def __add__(self, o): return self.get() + o
  def __radd__(self, o): return o + self.get()
  def __sub__(self, o): return self.get() - o
  def __rsub__(self, o): return o - self.get()
  def __mul__(self, o): return self.get() * o
  def __rmul__(self, o): return o * self.get()
  def __truediv__(self, o): return self.get() / o
  def __rtruediv__(self, o): return o / self.get()

  def __floordiv__(self, o):
    return self.get() // o

  def __rfloordiv__(self, o): return o // self.get()
  def __mod__(self, o): return self.get() % o
  def __rmod__(self, o): return o % self.get()
  def __lt__(self, o): return self.get() < o
  def __le__(self, o): return self.get() <= o
  def __gt__(self, o): return self.get() > o
  def __ge__(self, o): return self.get() >= o
  def __and__(self, o): return self.get() & o
  def __rand__(self, o): return o & self.get()
  def __or__(self, o): return self.get() | o
  def __ror__(self, o): return o | self.get()
  def __xor__(self, o): return self.get() ^ o
  def __rxor__(self, o): return o ^ self.get()
  def __getitem__(self, o): return self.get()[o]
  def __pow__(self, o, modulo=None): return pow(self.get(), o, modulo)
  def __rpow__(self, o): return pow(o, self.get())
  def __invert__(self): return ~self.get()
  def __neg__(self): return -self.get()
  def __abs__(self): return abs(self.get())

  def __div__(self, o):
    try:
      return self.get().__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self.get().__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self.get().__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self.get().__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  # TODO(josh11b): Even more operator overloads.


class PerReplica(DistributedValues):
  """Holds a map from device to unsynchronized values."""
  pass


# Note that unlike PerReplica, Mirrored values inherit from
# DistributedDelegate and so can be used directly in cross-replica mode.
class Mirrored(DistributedDelegate):
  """Holds a map from device to values which are kept in sync."""

  def _get_cross_replica(self):
    device = device_util.canonicalize(device_util.current())
    replica_id = self._device_map.replica_for_device(device)
    if replica_id is None:
      return self.primary
    return self._values[replica_id]

  def _as_graph_element(self):
    obj = self.get()
    conv_fn = getattr(obj, "_as_graph_element", None)
    if conv_fn and callable(conv_fn):
      return conv_fn()
    return obj
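

# Illustrative sketch (assumptions, not part of this module): the difference
# between the two wrappers above. With a two-replica `ReplicaDeviceMap` and
# hypothetical per-device tensors `loss_gpu0`, `loss_gpu1` and variable
# components `var_gpu0`, `var_gpu1`,
#
#   per_replica = PerReplica(device_map, (loss_gpu0, loss_gpu1))
#   mirrored = Mirrored(device_map, (var_gpu0, var_gpu1))
#
# `per_replica.get()` is only meaningful inside a replica context, where it
# returns that replica's component; `mirrored` may also be read in
# cross-replica mode (e.g. `mirrored + 1.0`), since all components hold the
# same value and `_get_cross_replica` falls back to `primary`.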


def _assign_on_device(device, variable, tensor):
  with ops.device(device):
    return variable.assign(array_ops.identity(tensor))


def _assert_strategy(strategy):
  if not distribution_strategy_context.has_strategy():
    raise RuntimeError(
        'Need to be inside "with strategy.scope()" for %s' %
        (strategy,))
  current_strategy = distribution_strategy_context.get_strategy()
  if current_strategy is not strategy:
    raise RuntimeError(
        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
        (current_strategy, strategy))


@contextlib.contextmanager
def _enter_or_assert_strategy(strategy):
  if not distribution_strategy_context.has_strategy():
    with strategy.scope():
      yield
  else:
    _assert_strategy(strategy)
    yield


DistributedVarOp = collections.namedtuple(
    "DistributedVarOp", ["name", "graph", "type"])


class DistributedVariable(DistributedDelegate):
  """Holds a map from device to variables."""
  # TODO(josh11b): Support changing the set of variables if, e.g., new
  # devices are joining or a device is to leave.

  def __init__(self, strategy, device_map, values, logical_device=None):
    self._distribute_strategy = strategy
    super(DistributedVariable, self).__init__(
        device_map, values, logical_device=logical_device)
    self._common_name = self.primary.name.split(":")[0]
    # Use a weakref to make it easy to map from the contained values
    # to the container without introducing a reference cycle.
    for v in values:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access
    # tf.keras keeps track of variables initialized using this attribute. When
    # tf.keras gets the default session, it initializes all uninitialized vars.
    # We need to make _keras_initialized a member of DistributedVariable because
    # without this it will use `__getattr__` which will delegate to a component
    # variable.
    self._keras_initialized = False
    # Typically, a `DistributedVariable`'s initializer is composed of the
    # initializers of the component variables. However, in some cases, such as
    # when restoring from a checkpoint, we may set the _initializer_op
    # property on the entire `DistributedVariable`.
    self._initializer_op = None

  def is_initialized(self, name=None):
    """Identifies if all the component variables are initialized.

    Args:
      name: Name of the final `logical_and` op.

    Returns:
      The op that evaluates to True or False depending on whether all the
      component variables are initialized.
    """
    result = self.primary.is_initialized()
    # We iterate through the list of values except the last one to allow us to
    # name the final `logical_and` op the same name that is passed by the user
    # to the `is_initialized` op. For distributed variables, the
    # `is_initialized` op is a `logical_and` op.
    for v in self._values[1:-1]:
      result = math_ops.logical_and(result, v.is_initialized())
    result = math_ops.logical_and(result, self._values[-1].is_initialized(),
                                  name=name)
    return result

  @property
  def initializer(self):
    if self._initializer_op:
      init_op = self._initializer_op
    else:
      # return grouped ops of all the var initializations of component values of
      # the mirrored variable
      init_op = control_flow_ops.group(tuple(
          v.initializer for v in self._values))
    return init_op

  def _get_closest(self):
    """Return member in the same replica if possible, else the primary."""
    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context:
      return self._device_map.select_for_current_replica(
          self._values, replica_context)
    device = distribute_lib.get_update_device()
    if device is None:
      device = device_util.canonicalize(device_util.current())
    replica_id = self._device_map.replica_for_device(device)
    if replica_id is None:
      return self.primary
    return self._values[replica_id]

  def initialized_value(self):
    return self._get_closest().initialized_value()

  @property
  def initial_value(self):
    return self._get_closest().initial_value

  @property
  def graph(self):
    return self.primary.graph

  @property
  def _shared_name(self):
    return self._common_name

  @property
  def _unique_id(self):
    return self.primary._unique_id  # pylint: disable=protected-access

  @property
  def _graph_key(self):
    """Lets Optimizers know which graph this variable is from."""
    return self.primary._graph_key  # pylint: disable=protected-access

  @property
  def name(self):
    return self.primary.name

  @property
  def dtype(self):
    return self.primary.dtype

  @property
  def shape(self):
    return self.primary.shape

  @property
  def trainable(self):
    return self.primary.trainable

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def get_shape(self):
    return self.primary.get_shape()

  def to_proto(self, export_scope=None):
    return self.primary.to_proto(export_scope=export_scope)

  @property
  def op(self):
    # We want cross-replica code that does some var.op.X calls
    # to work (even if the current device isn't in self.devices), but
    # other uses of var.op in a cross-replica context to fail.
    if distribution_strategy_context.in_cross_replica_context():
      return DistributedVarOp(self.primary.op.name,
                              self.primary.op.graph,
                              self.primary.op.type)
    return self.get().op

  @property
  def _in_graph_mode(self):
    return self.primary._in_graph_mode  # pylint: disable=protected-access

  def read_value(self):
    return self._distribute_strategy.extended.read_var(self)

  def value(self):
    return self._get_closest().value()

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass


ops.register_dense_tensor_like_type(DistributedVariable)


def _validate_colocate_extended(v, extended):
  variable_strategy = v._distribute_strategy  # pylint: disable=protected-access
  if variable_strategy.extended is not extended:
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not %s created in scope: %s" %
        (v, variable_strategy))


def validate_colocate_distributed_variable(v, extended):
  if not isinstance(v, DistributedVariable):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)


def validate_colocate_tpu_variable(v, extended):
  if not isinstance(v, TPUMirroredVariable):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)


def validate_colocate(v, extended):
  if not hasattr(v, "_distribute_strategy"):
    raise ValueError(
        "`colocate_vars_with` must only be passed a variable created in this "
        "tf.distribute.Strategy.scope(), not: %r" % (v,))
  _validate_colocate_extended(v, extended)


def _apply_aggregation(strategy, value, aggregation, destinations):
  if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
    return strategy.extended.broadcast_to(
        strategy.experimental_local_results(value)[0],
        destinations=destinations)
  reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
  return strategy.extended.reduce_to(reduce_op, value, destinations)
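

# Illustrative sketch (not part of this module): how the aggregation modes
# handled by `_apply_aggregation` behave for a per-replica value of
# (1.0, 3.0) on two replicas. The numbers are example assumptions; the
# mapping from `VariableAggregation` to `ReduceOp` is done by
# `reduce_util.ReduceOp.from_variable_aggregation`.
#
#   VariableAggregation.SUM                -> reduce_to(ReduceOp.SUM)  = 4.0
#   VariableAggregation.MEAN               -> reduce_to(ReduceOp.MEAN) = 2.0
#   VariableAggregation.ONLY_FIRST_REPLICA -> broadcast_to(1.0), i.e. the
#                                             first replica's value everywhere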


class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
  """Class for defining how to restore a MirroredVariable."""

  def __init__(self, mirrored_variable, primary_variable, name):
    self._mirrored_variable = mirrored_variable
    super(_MirroredSaveable, self).__init__(primary_variable, "", name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return control_flow_ops.group(tuple(
        _assign_on_device(v.device, v, tensor)
        for v in self._mirrored_variable.values))


class MirroredVariable(DistributedVariable, Mirrored,
                       trackable.Trackable):
  """Holds a map from device to variables whose values are kept in sync."""

  def __init__(
      self, strategy, device_map, values, aggregation, logical_device=None):
    super(MirroredVariable, self).__init__(
        strategy, device_map, values, logical_device=logical_device)
    self._aggregation = aggregation

  # The arguments to update() are automatically unwrapped so the update()
  # function would normally see regular variables, not MirroredVariables.
  # However, the update function can still operate on wrapped MirroredVariables
  # through object members, captured arguments, etc. This is more likely in an
  # update_non_slot() function (like OptimizerV2._finish), which can
  # update several non-slot variables in one call.
  def _assign_func(self, *args, **kwargs):
    with _enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if distribution_strategy_context.in_cross_replica_context():
        update_device = distribute_lib.get_update_device()
        if update_device is not None:
          # We are calling an assign function on the mirrored variable in an
          # update context.
          v = self.get(device=update_device)
          return f(v, *args, **kwargs)

        # We are calling assign on the mirrored variable in cross replica
        # context, use `strategy.extended.update()` to update the variable.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        _assert_replica_context(self._distribute_strategy)
        # We are calling an assign function on the mirrored variable in replica
        # context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function on each of the mirrored variables with the
        # reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError("You must specify an aggregation method to update a "
                           "MirroredVariable in Replica Context.")

        def merge_fn(strategy, value, *other_args, **other_kwargs):
          v = _apply_aggregation(strategy, value, self._aggregation, self)
          return strategy.extended.update(
              self, f, args=(v,) + other_args, kwargs=other_kwargs)

        return distribution_strategy_context.get_replica_context().merge_call(
            merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._assign_func(f=assign_fn, *args, **kwargs)

  @property
  def aggregation(self):
    return self._aggregation

  def _get_cross_replica(self):
    device = device_util.canonicalize(device_util.current())
    replica_id = self._device_map.replica_for_device(device)
    if replica_id is None:
      return array_ops.identity(self.primary)
    return array_ops.identity(self._values[replica_id])

  def _as_graph_element(self):
    # pylint: disable=protected-access
    if distribution_strategy_context.in_cross_replica_context():
      return self.primary._as_graph_element()
    return self.get()._as_graph_element()

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    MirroredVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
735 """ 736 def _saveable_factory(name=self._common_name): 737 return _MirroredSaveable(self, self.primary, name) 738 return {trackable.VARIABLE_VALUE_KEY: _saveable_factory} 739 740 741# Register a conversion function which reads the value of the variable, 742# allowing instances of the class to be used as tensors. 743def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False): 744 # Try to avoid assignments to and other mutations of MirroredVariable 745 # state except through a DistributionStrategy.extended.update() call. 746 assert not as_ref 747 return ops.internal_convert_to_tensor( 748 var.get(), dtype=dtype, name=name, as_ref=as_ref) 749 750 751ops.register_tensor_conversion_function(MirroredVariable, 752 _tensor_conversion_mirrored) 753 754 755def _enclosing_tpu_context(): 756 # pylint: disable=protected-access 757 tpu_context = ops.get_default_graph()._get_control_flow_context() 758 # pylint: enable=protected-access 759 while tpu_context is not None and not isinstance( 760 tpu_context, control_flow_ops.XLAControlFlowContext): 761 tpu_context = tpu_context.outer_context 762 return tpu_context 763 764 765# TODO(jhseu): Deduplicate code. We copy code because we don't want to 766# inherit from DistributedDelegate. DistributedDelegate will not work in a 767# tpu.replicate() because it assumes that you're in a device context where you 768# can operate on a single version of the variable, but a tpu.replicate() 769# operates on all variables and is replicated during a rewrite pass. 770class TPUMirroredVariable(trackable.Trackable): 771 """Holds a map from device to TPU variables whose values are kept in sync.""" 772 773 def __init__( 774 self, strategy, device_map, values, aggregation, logical_device=None): 775 assert isinstance(device_map, DeviceMap) 776 self._distribute_strategy = strategy 777 self._device_map = device_map 778 self._values = tuple(values) 779 if logical_device is None: 780 logical_device = device_map.logical_device_from_values(self._values) 781 self._logical_device = logical_device 782 783 # Use a weakref to make it easy to map from the contained values 784 # to the container without introducing a reference cycle. 785 for v in self._values: 786 v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access 787 self._common_name = self.primary.name.split(":")[0] 788 self._aggregation = aggregation 789 # Needed for GradientTape 790 self._trainable = self.primary.trainable 791 # Typically like `DistributedVariable`, a `TPUMirroredVariable`'s 792 # initializer is composed of the initializers of the components variables. 793 # However, in some cases, such as when restoring from a checkpoint, we may 794 # set the _initializer_op property on the entire `TPUMirroredVariable`. 


def _enclosing_tpu_context():
  # pylint: disable=protected-access
  tpu_context = ops.get_default_graph()._get_control_flow_context()
  # pylint: enable=protected-access
  while tpu_context is not None and not isinstance(
      tpu_context, control_flow_ops.XLAControlFlowContext):
    tpu_context = tpu_context.outer_context
  return tpu_context


# TODO(jhseu): Deduplicate code. We copy code because we don't want to
# inherit from DistributedDelegate. DistributedDelegate will not work in a
# tpu.replicate() because it assumes that you're in a device context where you
# can operate on a single version of the variable, but a tpu.replicate()
# operates on all variables and is replicated during a rewrite pass.
class TPUMirroredVariable(trackable.Trackable):
  """Holds a map from device to TPU variables whose values are kept in sync."""

  def __init__(
      self, strategy, device_map, values, aggregation, logical_device=None):
    assert isinstance(device_map, DeviceMap)
    self._distribute_strategy = strategy
    self._device_map = device_map
    self._values = tuple(values)
    if logical_device is None:
      logical_device = device_map.logical_device_from_values(self._values)
    self._logical_device = logical_device

    # Use a weakref to make it easy to map from the contained values
    # to the container without introducing a reference cycle.
    for v in self._values:
      v._mirrored_container = weakref.ref(self)  # pylint: disable=protected-access
    self._common_name = self.primary.name.split(":")[0]
    self._aggregation = aggregation
    # Needed for GradientTape
    self._trainable = self.primary.trainable
    # Typically like `DistributedVariable`, a `TPUMirroredVariable`'s
    # initializer is composed of the initializers of the component variables.
    # However, in some cases, such as when restoring from a checkpoint, we may
    # set the _initializer_op property on the entire `TPUMirroredVariable`.
    self._initializer_op = None

  def _get(self, device=None):
    """Returns the value for the current device or raises a ValueError."""
    if device is None:
      replica_context = distribution_strategy_context.get_replica_context()
      if replica_context:
        return self._device_map.select_for_current_replica(
            self._values, replica_context)
      else:
        device = distribute_lib.get_update_device()
        if device is None:
          return self._get_cross_replica()
    device = device_util.canonicalize(device)
    return self._device_map.select_for_device(self._values, device)

  @property
  def primary(self):
    """Returns a representative component."""
    return self._values[0]

  @property
  def devices(self):
    return self._device_map.logical_to_actual_devices(self._logical_device)

  @property
  def logical_device(self):
    return self._logical_device

  @property
  def device_map(self):
    return self._device_map

  # TODO(josh11b): Replace experimental_local_results with this?
  @property
  def values(self):
    return self._values

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  # pylint: disable=multiple-statements
  def __add__(self, o): return self.read_value() + o
  def __radd__(self, o): return o + self.read_value()
  def __sub__(self, o): return self.read_value() - o
  def __rsub__(self, o): return o - self.read_value()
  def __mul__(self, o): return self.read_value() * o
  def __rmul__(self, o): return o * self.read_value()
  def __truediv__(self, o): return self.read_value() / o
  def __rtruediv__(self, o): return o / self.read_value()
  def __floordiv__(self, o): return self.read_value() // o
  def __rfloordiv__(self, o): return o // self.read_value()
  def __mod__(self, o): return self.read_value() % o
  def __rmod__(self, o): return o % self.read_value()
  def __lt__(self, o): return self.read_value() < o
  def __le__(self, o): return self.read_value() <= o
  def __gt__(self, o): return self.read_value() > o
  def __ge__(self, o): return self.read_value() >= o
  def __and__(self, o): return self.read_value() & o
  def __rand__(self, o): return o & self.read_value()
  def __or__(self, o): return self.read_value() | o
  def __ror__(self, o): return o | self.read_value()
  def __xor__(self, o): return self.read_value() ^ o
  def __rxor__(self, o): return o ^ self.read_value()
  def __getitem__(self, o): return self.read_value()[o]
  def __pow__(self, o, modulo=None): return pow(self.read_value(), o, modulo)
  def __rpow__(self, o): return pow(o, self.read_value())
  def __invert__(self): return ~self.read_value()
  def __neg__(self): return -self.read_value()
  def __abs__(self): return abs(self.read_value())

  def __div__(self, o):
    try:
      return self.read_value().__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self.read_value().__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self.read_value().__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self.read_value().__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __str__(self):
    devices = self.devices
    debug_str = ",\n".join("  %d %s: %s" % (i, devices[i], self._values[i])
                           for i in range(len(devices)))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)

  def __repr__(self):
    devices = self.devices
    debug_repr = ",\n".join("  %d %s: %r" % (i, devices[i], self._values[i])
                            for i in range(len(devices)))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)

  @property
  def handle(self):
    # If we're in a tpu.rewrite(), return the replicated handle.
    tpu_context = _enclosing_tpu_context()
    if tpu_context is not None:
      return tpu_context.get_replicated_var_handle(
          self._common_name, self._values)

    device = distribute_lib.get_update_device()
    if device is None:
      return self.primary.handle
    return self._get(device=device).handle

  @property
  def device(self):
    return self.handle.device

  def eval(self, session=None):
    return self.primary.eval(session)

  # The arguments to update() are automatically unwrapped so the update()
  # function would normally see regular variables, not MirroredVariables.
  # However, the update function can still operate on wrapped MirroredVariables
  # through object members, captured arguments, etc. This is more likely in an
  # update_non_slot() function (like OptimizerV2._finish), which can
  # update several non-slot variables in one call.
  def _assign_func(self, *args, **kwargs):
    with _enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if distribution_strategy_context.in_cross_replica_context():
        if _enclosing_tpu_context() is not None:
          return self._distribute_strategy.extended.update(
              self, f, args=args, kwargs=kwargs)

        update_device = distribute_lib.get_update_device()
        # We are calling update on the mirrored variable in cross replica
        # context.
        if update_device is not None:
          # We are calling an assign function on the mirrored variable in cross
          # replica context.
          v = self._get(device=update_device)
          return f(v, *args, **kwargs)

        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        _assert_replica_context(self._distribute_strategy)
        # We are calling an assign function on the mirrored variable in replica
        # context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function on each of the mirrored variables with the
        # reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError("You must specify an aggregation method to update a "
                           "TPUMirroredVariable in Replica Context.")

        def merge_fn(strategy, value, *other_args, **other_kwargs):
          v = _apply_aggregation(strategy, value, self._aggregation, self)
          return strategy.extended.update(
              self, f, args=(v,) + other_args, kwargs=other_kwargs)

        return distribution_strategy_context.get_replica_context().merge_call(
            merge_fn, args=args, kwargs=kwargs)

  @contextlib.contextmanager
  def _handle_graph(self, handle):
    # Note: might have an eager tensor but not be executing eagerly when
    # building functions.
    if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor)
        or ops.has_default_graph()):
      yield
    else:
      with handle.graph.as_default():
        yield

  @property
  def trainable(self):
    return self._trainable

  def _read_variable_op(self, parent_op=None):
    if self.trainable:
      tape.variable_accessed(self)
    if parent_op is not None:
      with ops.control_dependencies([parent_op]):
        return gen_resource_variable_ops.read_variable_op(
            self.handle, self.dtype)

    return gen_resource_variable_ops.read_variable_op(
        self.handle, self.dtype)

  def read_value(self):
    return self._read_variable_op()

  def assign_sub(self, *args, **kwargs):
    def assign_sub_fn(var, delta, *ar, **kw):
      del ar
      name = kw.pop("name", None)
      read_value = kw.pop("read_value", True)
      with self._handle_graph(var.handle):
        op = gen_resource_variable_ops.assign_sub_variable_op(
            var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
            name=name)
      if read_value:
        return self._read_variable_op(parent_op=op)
      return op

    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    def assign_add_fn(var, delta, *ar, **kw):
      del ar
      name = kw.pop("name", None)
      read_value = kw.pop("read_value", True)
      with self._handle_graph(var.handle):
        op = gen_resource_variable_ops.assign_add_variable_op(
            var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
            name=name)
      if read_value:
        return self._read_variable_op(parent_op=op)
      return op

    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    def assign_fn(var, value, *ar, **kw):
      del ar
      name = kw.pop("name", None)
      read_value = kw.pop("read_value", True)
      with self._handle_graph(var.handle):
        op = gen_resource_variable_ops.assign_variable_op(
            var.handle, ops.convert_to_tensor(value, dtype=self.dtype),
            name=name)
      if read_value:
        return self._read_variable_op(parent_op=op)
      return op

    return self._assign_func(f=assign_fn, *args, **kwargs)

  @property
  def aggregation(self):
    return self._aggregation

  @property
  def constraint(self):
    return None

  @property
  def initializer(self):
    if self._initializer_op:
      init_op = self._initializer_op
    else:
      init_op = control_flow_ops.group(tuple(
          v.initializer for v in self._values))
    return init_op

  @property
  def graph(self):
    return self.primary.graph

  @property
  def _shared_name(self):
    return self._common_name

  @property
  def _unique_id(self):
    return self.primary._unique_id  # pylint: disable=protected-access

  @property
  def name(self):
    return self.primary.name

  @property
  def dtype(self):
    return self.primary.dtype

  @property
  def shape(self):
    return self.primary.shape

  def get_shape(self):
    return self.primary.get_shape()

  def to_proto(self, export_scope=None):
    return self.primary.to_proto(export_scope=export_scope)

  def _get_cross_replica(self):
    device = device_util.canonicalize(device_util.current())
    replica = self._device_map.replica_for_device(device)
    if replica is None:
      return self.primary
    return self._values[replica]

  def _as_graph_element(self):
    # pylint: disable=protected-access
    if _enclosing_tpu_context() is None:
      if distribution_strategy_context.in_cross_replica_context():
        return self.primary._as_graph_element()
      return self._get()._as_graph_element()
    return None

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    MirroredVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """
    def _saveable_factory(name=self._common_name):
      return _MirroredSaveable(self, self.primary, name)
    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  # Needed to pass ResourceVariable checks.
  @property
  def op(self):
    return self.primary.op

  # pylint: disable=protected-access
  @property
  def _save_slice_info(self):
    return self.primary._save_slice_info

  def _get_save_slice_info(self):
    return self.primary._get_save_slice_info()

  def _set_save_slice_info(self, save_slice_info):
    return self.primary._set_save_slice_info(save_slice_info)
  # pylint: enable=protected-access

  @property
  def _in_graph_mode(self):
    return self.primary._in_graph_mode  # pylint: disable=protected-access

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # pylint: disable=protected-access
    if _enclosing_tpu_context() is None:
      return self._get()._dense_var_to_tensor(dtype, name, as_ref)
    # pylint: enable=protected-access
    if dtype is not None and dtype != self.dtype:
      return math_ops.cast(self.read_value(), dtype)
    if as_ref:
      return self.handle
    else:
      return self.read_value()

  def is_initialized(self, name=None):
    """Identifies if all the component variables are initialized.

    Args:
      name: Name of the final `logical_and` op.

    Returns:
      The op that evaluates to True or False depending on whether all the
      component variables are initialized.
    """
    # TODO(jhseu): Do we need TPU context implementation?

    result = self.primary.is_initialized()
    # We iterate through the list of values except the last one to allow us to
    # name the final `logical_and` op the same name that is passed by the user
    # to the `is_initialized` op. For distributed variables, the
    # `is_initialized` op is a `logical_and` op.
    for v in self._values[1:-1]:
      result = math_ops.logical_and(result, v.is_initialized())
    result = math_ops.logical_and(result, self._values[-1].is_initialized(),
                                  name=name)
    return result


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_tpu_mirrored(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(TPUMirroredVariable,
                                        _tensor_conversion_tpu_mirrored)
ops.register_dense_tensor_like_type(TPUMirroredVariable)


class _SyncOnReadSaveable(saver.BaseSaverBuilder.SaveableObject):
  """Class for defining how to restore a SyncOnReadVariable."""

  def __init__(self, sync_on_read_variable, name):
    self._sync_on_read_variable = sync_on_read_variable
    # We use a callable so that we don't have to evaluate this expression
    # in the case where we are trying to restore instead of save.
    def tensor():
      strategy = sync_on_read_variable._distribute_strategy  # pylint: disable=protected-access
      return strategy.extended.read_var(sync_on_read_variable)

    spec = saver.BaseSaverBuilder.SaveSpec(
        tensor=tensor,
        slice_spec="",
        name=name,
        dtype=sync_on_read_variable.dtype)
    super(_SyncOnReadSaveable, self).__init__(tensor, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return self._sync_on_read_variable.assign(tensor)


def _assert_replica_context(strategy):
  replica_context = distribution_strategy_context.get_replica_context()
  if not replica_context:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")
  if replica_context.strategy is not strategy:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")


class SyncOnReadVariable(DistributedVariable, PerReplica, trackable.Trackable):
  """Holds a map from device to variables whose values are reduced on save."""

  def __init__(
      self, strategy, device_map, values, aggregation, logical_device=None):
    self._aggregation = aggregation
    super(SyncOnReadVariable, self).__init__(
        strategy, device_map, values, logical_device=logical_device)

  def assign_sub(self, *args, **kwargs):
    _assert_replica_context(self._distribute_strategy)
    return self.get().assign_sub(*args, **kwargs)

  def assign_add(self, *args, **kwargs):
    _assert_replica_context(self._distribute_strategy)
    return self.get().assign_add(*args, **kwargs)

  def assign(self, *args, **kwargs):
    if distribution_strategy_context.in_cross_replica_context():
      # To preserve the sum across save and restore, we have to divide the
      # total across all devices when restoring a variable that was summed
      # when saving.
      tensor = args[0]
      if self._aggregation == vs.VariableAggregation.SUM:
        tensor *= 1. / len(self.devices)
      return control_flow_ops.group(tuple(
          _assign_on_device(v.device, v, tensor) for v in self._values))
    else:
      _assert_replica_context(self._distribute_strategy)
      return self.get().assign(*args, **kwargs)

  @property
  def aggregation(self):
    return self._aggregation

  def _get_cross_replica(self):
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      return self.primary
    return self._distribute_strategy.reduce(
        reduce_util.ReduceOp.from_variable_aggregation(self.aggregation), self)

  def _as_graph_element(self):
    # pylint: disable=protected-access
    if distribution_strategy_context.in_cross_replica_context():
      return self._get_cross_replica()
    return self.get()._as_graph_element()

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    `SyncOnReadVariable`s.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """
    def _saveable_factory(name=self._common_name):
      return _SyncOnReadSaveable(self, name)
    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}


# Register a conversion function for SyncOnReadVariable which allows as_ref to
# be true.
def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False):
  return ops.internal_convert_to_tensor(
      var.get(), dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(SyncOnReadVariable,
                                        _tensor_conversion_sync_on_read)
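

# Illustrative sketch (assumptions, not part of this module): a
# SyncOnReadVariable with SUM aggregation, e.g. a metric counter. Each replica
# updates only its local component in replica context; a cross-replica read
# reduces the components, and saving/restoring goes through
# _SyncOnReadSaveable. With the hypothetical two-replica state below:
#
#   components: (3.0 on GPU:0, 5.0 on GPU:1), aggregation=SUM
#   cross-replica read / checkpoint save -> 3.0 + 5.0 = 8.0
#   restore of 8.0 with SUM              -> each component gets 8.0 / 2 = 4.0,
#                                           so a later summed read is 8.0 again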


def regroup(device_map, values, wrap_class=PerReplica):
  """Makes a nest per-replica into a nest of PerReplica/Mirrored values."""
  assert isinstance(device_map, DeviceMap)
  assert len(values) == device_map.num_replicas_in_graph
  v0 = values[0]

  if isinstance(v0, list):
    for v in values[1:]:
      assert isinstance(v, list)
      assert len(v) == len(v0), ("len(v) == %d, len(v0) == %d, v: %s, v0: %s" %
                                 (len(v), len(v0), v, v0))
    return [regroup(device_map, tuple(v[i] for v in values), wrap_class)
            for i in range(len(v0))]

  if isinstance(v0, tuple):
    for v in values[1:]:
      assert isinstance(v, tuple)
      assert len(v) == len(v0)
    regrouped_tuple = tuple(
        regroup(device_map, tuple(v[i] for v in values), wrap_class)
        for i in range(len(v0)))
    if hasattr(v0, "_fields"):
      # This tuple is in fact a namedtuple! Create a new namedtuple instance
      # and initialize it with the regrouped values:
      assert hasattr(type(v0), "_make")
      return type(v0)._make(regrouped_tuple)
    else:
      return regrouped_tuple

  if isinstance(v0, dict):
    v0keys = set(v0.keys())
    for v in values[1:]:
      assert isinstance(v, dict), ("v[0]: %r  v[i]: %r" % (v0, v))
      assert set(v.keys()) == v0keys, ("v[0].keys: %s  v[i].keys: %s" %
                                       (v0keys, set(v.keys())))
    return {key: regroup(device_map, tuple(v[key] for v in values), wrap_class)
            for key in v0keys}

  # If exactly the same object across all devices, return it unwrapped.
  same_id = True
  for v in values[1:]:
    if v is not v0:
      same_id = False
      break
  # Consider three cases where same_id is true:
  # * If v0 is a DistributedVariable (a MirroredVariable or
  #   SyncOnReadVariable, and same_id means it is the same across all
  #   devices), we want to return it.  We check DistributedVariable
  #   specifically since it can look like it has a
  #   _distributed_container member since its members do.
  # * If v0 is a member of a distributed variable, in which case
  #   hasattr(v0, "_distributed_container") is true, we want to
  #   return the DistributedVariable that contains it using the
  #   _distributed_container logic below. This case can trigger
  #   same_id when there is only one device.
  # * In any other situation, same_id means we return v0.
  if same_id and (isinstance(v0, DistributedVariable) or
                  not hasattr(v0, "_distributed_container")):
    return v0

  # Detect the case where each device has a parallel component of the
  # same MirroredVariable (or SyncOnReadVariable). In this case we
  # want to return the containing MirroredVariable, after a bunch of
  # sanity checking. In particular, each component should have the
  # same container, and the devices of the variables should match the
  # keys of the per-replica dictionary.
  if hasattr(v0, "_distributed_container"):
    # pylint: disable=protected-access
    assert not isinstance(v0, MirroredVariable), (
        "ids = %s, values = %s" % ([id(v) for v in values], values))
    assert device_map.is_device_in_replica(v0.device, 0), (
        "v0.device = %s, device_map = %s" % (v0.device, device_map))
    distributed_container = v0._distributed_container()
    assert distributed_container is not None
    for r, v in enumerate(values[1:]):
      assert device_map.is_device_in_replica(v.device, r + 1), (
          "v.device = %s, r = %d, device_map = %s" %
          (v.device, r + 1, device_map))
      assert distributed_container is v._distributed_container()
    return distributed_container
    # pylint: enable=protected-access

  return wrap_class(device_map, values)
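

# Illustrative sketch (assumptions, not part of this module): `regroup` turns
# a per-replica sequence of structures into a structure of per-replica values.
# With a two-replica `ReplicaDeviceMap` and hypothetical per-replica results
#
#   values = ({"loss": 1.0, "logits": a0}, {"loss": 2.0, "logits": a1})
#   regrouped = regroup(device_map, values)
#
# `regrouped` is a dict {"loss": PerReplica((1.0, 2.0)),
#                        "logits": PerReplica((a0, a1))}.
# If instead each entry of `values` is the component of the same
# MirroredVariable on its replica's device, `regroup` returns the containing
# MirroredVariable itself rather than wrapping the components.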


def select_replica(replica_id, structured):
  """Specialize a nest of regular & per-replica values for one replica."""
  def _get(x):
    return x.values[replica_id] if isinstance(x, DistributedValues) else x

  return nest.map_structure(_get, structured)


def select_device_mirrored(device, structured):
  """Specialize a nest of regular & mirrored values for one device."""
  def _get_mirrored(x):
    if isinstance(x, DistributedValues):
      if not isinstance(x, Mirrored):
        raise TypeError(
            "Expected value to be mirrored across replicas: %s in %s." %
            (x, structured))
      return x.get(device)
    else:
      return x

  return nest.map_structure(_get_mirrored, structured)


def update_regroup(extended, device_map, updates, group):
  """Regroup for an update, with dependencies to ensure all updates execute."""
  # TODO(josh11b): Replace "Mirrored" here with a function that does the following
  # so we can avoid all these nest operations.
  regrouped = regroup(device_map, updates, Mirrored)
  if not group:
    return nest.map_structure(extended._local_results, regrouped)  # pylint: disable=protected-access
  grouped_flat = []
  for u in nest.flatten(regrouped):
    if isinstance(u, DistributedValues):
      g = extended._group(u)  # pylint: disable=protected-access
      if u.is_tensor_like:
        # Make sure we run all updates. Without this, something like
        # session.run(extended.update(...)) may only update one replica.
        values = []
        for d in u.devices:
          with ops.device(d), ops.control_dependencies([g]):
            values.append(array_ops.identity(u.get(d)))
        g = Mirrored(u.device_map, values)
    else:
      g = u
    grouped_flat.append(g)
  return nest.pack_sequence_as(regrouped, grouped_flat)
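

# Illustrative sketch (assumptions, not part of this module): `select_replica`
# is the inverse of `regroup` for a single replica. Continuing the hypothetical
# example above,
#
#   select_replica(0, regrouped)  # -> {"loss": 1.0, "logits": a0}
#   select_replica(1, regrouped)  # -> {"loss": 2.0, "logits": a1}
#
# Plain (non-distributed) values in the nest are passed through unchanged.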


def value_container(val):
  """Returns the container that this per-replica `value` belongs to.

  Args:
    val: A value returned by `call_for_each_replica()` or a variable
      created in `scope()`.

  Returns:
    A container that `value` belongs to.
    If value does not belong to any container (including the case of
    container having been destroyed), returns the value itself.
  """
  if (hasattr(val, "_distributed_container") and
      # DistributedVariable has _distributed_container defined
      # but we don't want to return it.
      not isinstance(val, DistributedVariable)):
    container = val._distributed_container()  # pylint: disable=protected-access
    if container is not None:
      return container
  return val


# TODO(josh11b): Descend from Variable.
class AggregatingVariable(trackable.Trackable):
  """A wrapper around a variable that aggregates updates across replicas."""

  def __init__(self, strategy, v, aggregation):
    self._distribute_strategy = strategy
    self._v = v
    # NOTE: We don't use "_distributed_container" here because we don't want
    # to trigger that code path in regroup().
    v._aggregating_container = weakref.ref(self)  # pylint: disable=protected-access
    self._aggregation = aggregation

  def get(self):
    return self._v

  @property
  def distribute_strategy(self):
    return self._distribute_strategy

  def __getattr__(self, name):
    return getattr(self._v, name)

  def _assign_func(self, *args, **kwargs):
    with _enter_or_assert_strategy(self._distribute_strategy):
      f = kwargs.pop("f")
      if distribution_strategy_context.in_cross_replica_context():
        update_device = distribute_lib.get_update_device()
        if update_device is not None:
          # We are calling an assign function in an update context.
          return f(self._v, *args, **kwargs)

        # We are calling an assign function in cross replica context, wrap it in
        # an update call.
        return self._distribute_strategy.extended.update(
            self, f, args=args, kwargs=kwargs)
      else:
        replica_context = distribution_strategy_context.get_replica_context()
        assert replica_context
        # We are calling an assign function in replica context.
        # We reduce the value we want to assign/add/sub. More details about how
        # we handle the different use cases can be found in the _reduce method.
        # We call the function with the reduced value.
        if self._aggregation == vs.VariableAggregation.NONE:
          raise ValueError("You must specify an aggregation method to update "
                           "a variable in replica context.")

        def merge_fn(strategy, value, *other_args, **other_kwargs):
          v = _apply_aggregation(strategy, value, self._aggregation, self)
          return strategy.extended.update(
              self, f, args=(v,) + other_args, kwargs=other_kwargs)

        return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs)

  def assign_sub(self, *args, **kwargs):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._assign_func(f=assign_sub_fn, *args, **kwargs)

  def assign_add(self, *args, **kwargs):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._assign_func(f=assign_add_fn, *args, **kwargs)

  def assign(self, *args, **kwargs):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._assign_func(f=assign_fn, *args, **kwargs)

  @property
  def aggregation(self):
    return self._aggregation

  @property
  def name(self):
    return self._v.name

  @property
  def dtype(self):
    return self._v.dtype

  # TODO(josh11b): Test saving & restoring.
  def _gather_saveables_for_checkpoint(self):
    return {trackable.VARIABLE_VALUE_KEY: self._v}

  # pylint: disable=multiple-statements
  def __add__(self, o): return self._v + o
  def __radd__(self, o): return o + self._v
  def __sub__(self, o): return self._v - o
  def __rsub__(self, o): return o - self._v
  def __mul__(self, o): return self._v * o
  def __rmul__(self, o): return o * self._v
  def __truediv__(self, o): return self._v / o
  def __rtruediv__(self, o): return o / self._v
  def __floordiv__(self, o): return self._v // o
  def __rfloordiv__(self, o): return o // self._v
  def __mod__(self, o): return self._v % o
  def __rmod__(self, o): return o % self._v
  def __lt__(self, o): return self._v < o
  def __le__(self, o): return self._v <= o
  def __gt__(self, o): return self._v > o
  def __ge__(self, o): return self._v >= o
  def __and__(self, o): return self._v & o
  def __rand__(self, o): return o & self._v
  def __or__(self, o): return self._v | o
  def __ror__(self, o): return o | self._v
  def __xor__(self, o): return self._v ^ o
  def __rxor__(self, o): return o ^ self._v
  def __getitem__(self, o): return self._v[o]
  def __pow__(self, o, modulo=None): return pow(self._v, o, modulo)
  def __rpow__(self, o): return pow(o, self._v)
  def __invert__(self): return ~self._v
  def __neg__(self): return -self._v
  def __abs__(self): return abs(self._v)

  def __div__(self, o):
    try:
      return self._v.__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self._v.__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self._v.__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self._v.__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __str__(self):
    return str(self._v)

  def __repr__(self):
    return repr(self._v)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_aggregate(var, dtype=None, name=None, as_ref=False):
  return ops.internal_convert_to_tensor(
      var.get(), dtype=dtype, name=name, as_ref=as_ref)


ops.register_tensor_conversion_function(
    AggregatingVariable, _tensor_conversion_aggregate)
ops.register_dense_tensor_like_type(AggregatingVariable)
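

# Illustrative sketch (assumptions, not part of this module): an
# AggregatingVariable wraps a single underlying variable (e.g. one placed on a
# parameter server) rather than one copy per replica. Reads delegate to the
# wrapped variable; writes from replica context are first reduced according to
# `aggregation`, so with MEAN aggregation and two replicas calling
# `v.assign_add(1.0)` and `v.assign_add(3.0)`, the wrapped variable is updated
# once with (1.0 + 3.0) / 2 = 2.0.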