# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A variable which packs a list of variables distributed across devices."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops


class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
  """A variable which packs multiple variables distributed across devices.

  It's only supported when eager execution is enabled.
  For op-by-op execution, use an unpacked handle on the current device; for
  function execution, use the packed handle to reduce the overhead of function
  calls.
  """

  def __init__(self, distributed_variables=None, name=None, **unused_kwargs):
    """Packs a list of variables which are distributed across devices.

    Args:
      distributed_variables: A list of distributed Variables to pack.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
    """
    if not ops.executing_eagerly_outside_functions():
      raise ValueError(
          "PackedDistributedVariable should be created in eager mode.")
    if not distributed_variables:
      raise ValueError("Expect a non-empty list of variables to pack.")
    for i, var in enumerate(distributed_variables):
      if not resource_variable_ops.is_resource_variable(var):
        raise ValueError("Expect a list of ResourceVariables to pack, "
                         "but the %d-th variable is %s" % (i, type(var)))

    self._distributed_variables = distributed_variables
    self._devices = [v.device for v in distributed_variables]
    with ops.init_scope():
      with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
        handle = ops.pack_eager_tensors(
            [var.handle for var in distributed_variables])
        handle_name = ops.name_from_scope_name(name)
        unique_id = "%s_%d" % (handle_name, ops.uid())
      super(PackedDistributedVariable, self).__init__(
          trainable=distributed_variables[0].trainable,
          shape=distributed_variables[0].shape,
          dtype=distributed_variables[0].dtype,
          handle=handle,
          synchronization=distributed_variables[0].synchronization,
          constraint=distributed_variables[0].constraint,
          aggregation=distributed_variables[0].aggregation,
          distribute_strategy=distributed_variables[0]._distribute_strategy,  # pylint: disable=protected-access
          name=name,
          unique_id=unique_id,
          handle_name=handle_name,
          graph_element=None,
          initial_value=None,
          initializer_op=None,
          is_initialized_op=None,
          cached_value=None,
          caching_device=None,
          is_distributed_variables=True)

  @property
  def devices(self):
    return self._devices

  def on_device(self, device):
    return PackedVarAndDevice(self, device)

  def get_var_on_device(self, device):
    for i, d in enumerate(self._devices):
      if d == device:
        return self._distributed_variables[i]
    raise ValueError("Device %s is not found" % device)

  def get_var_on_current_device(self):
    current_device = device_util.canonicalize(device_util.current())
    return self.get_var_on_device(current_device)

  def initial_value(self, device):
    """Returns the Tensor used as the initial value for the variable."""
    return self.get_var_on_device(device).initial_value

  @property
  def handle(self):
    if context.executing_eagerly():
      return self.get_var_on_current_device().handle
    else:
      return self._handle

  @property
  def packed_handle(self):
    return self._handle

  def _read_variable_op(self):
    if context.executing_eagerly():
      return self.get_var_on_current_device().value()
    else:
      return super(PackedDistributedVariable, self)._read_variable_op()

  def value(self):
    return self._read_variable_op()

  def is_initialized(self, name=None):
    if context.executing_eagerly():
      result = self._distributed_variables[0].is_initialized()
      for v in self._distributed_variables[1:-1]:
        result = math_ops.logical_and(result, v.is_initialized())
      result = math_ops.logical_and(
          result, self._distributed_variables[-1].is_initialized(), name=name)
    else:
      with ops.device(self._devices[0]):
        result = super(PackedDistributedVariable, self).is_initialized(name)
      for d in self._devices[1:-1]:
        with ops.device(d):
          initialized = super(PackedDistributedVariable,
                              self).is_initialized(name)
        result = math_ops.logical_and(result, initialized)
      with ops.device(self._devices[-1]):
        initialized = super(PackedDistributedVariable,
                            self).is_initialized(name)
        result = math_ops.logical_and(result, initialized, name=name)
    return result

  def _update(self, update_fn, value, **kwargs):
    if context.executing_eagerly():
      return update_fn(self.get_var_on_current_device(), value, **kwargs)
    else:
      return update_fn(super(PackedDistributedVariable, self), value, **kwargs)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._update(
        update_fn=assign_sub_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._update(
        update_fn=assign_add_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._update(
        update_fn=assign_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
    return self._update(
        update_fn=scatter_sub_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
    return self._update(
        update_fn=scatter_add_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
    return self._update(
        update_fn=scatter_mul_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
    return self._update(
        update_fn=scatter_div_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
    return self._update(
        update_fn=scatter_min_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
    return self._update(
        update_fn=scatter_max_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
    return self._update(
        update_fn=scatter_update_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    if context.executing_eagerly():
      return self.get_var_on_current_device()._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)
    else:
      return super(PackedDistributedVariable, self)._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)
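

# The sketch below is illustrative and not part of the original API: a minimal
# example, assuming eager execution and two packable ResourceVariables (the
# values, names, and device strings are hypothetical). Op-by-op, `handle` and
# `value()` resolve to the unpacked variable on the current device; inside a
# tf.function the packed handle would be used instead.
def _example_pack_and_read():  # Illustrative sketch only.
  with ops.device("CPU:0"):
    v0 = resource_variable_ops.ResourceVariable(1.0, name="v0")
  with ops.device("CPU:0"):  # Use e.g. "GPU:0" if a second device is present.
    v1 = resource_variable_ops.ResourceVariable(2.0, name="v1")
  packed = PackedDistributedVariable([v0, v1], name="packed")
  # Pin the read to one of the packed devices; `value()` then reads the
  # per-device variable rather than the packed handle.
  with ops.device(packed.devices[0]):
    return packed.value()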


class PackedVarAndDevice(object):
  """Holds a packed distributed variable and a device."""

  def __init__(self, var, device):
    self._var = var
    self._device = device

  def __getattr__(self, name):
    # Exceptions raised inside the context manager can create a reference
    # cycle [1]: the cycle involves the current frame, which holds a reference
    # to the outer frame. TensorFlow (e.g. iterators) relies on object
    # finalizers to clean up resources, and such references prevent the
    # resource from being deleted, which can cause leaks and errors. One corner
    # case is that iterators are kept alive and the garbage collector happens
    # to run after auto control dependencies; the deletion then loses the
    # control dependencies on operations that use such resources.
    #
    # Catching and re-raising the exception seems to work around the issue.
    #
    # [1] https://bugs.python.org/issue43533
    try:
      with ops.device(self._device):
        return getattr(self._var, name)
    except:  # pylint: disable=try-except-raise
      raise

  def var(self):
    return self._var

  def value(self):
    with ops.device(self._device):
      return self._var.value()

  def read_value(self):
    with ops.device(self._device):
      return self._var.read_value()

  @property
  def initial_value(self):
    return self._var.initial_value(self._device)

  def initialized_value(self):
    with ops.device(self._device):
      return self._var.initialized_value()

  @property
  def device(self):
    return self._device

  @property
  def handle(self):
    with ops.device(self._device):
      return self._var.handle

  def on_device_handle(self):
    with ops.device(self._device):
      return self._var.get_var_on_current_device().handle

  @property
  def op(self):
    with ops.device(self._device):
      return self._var.op

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign_sub(delta, use_locking, name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign_add(delta, use_locking, name, read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign(value, use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_sub(sparse_delta, use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_add(sparse_delta, use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_mul(sparse_delta, use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_div(sparse_delta, use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_min(sparse_delta, use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_max(sparse_delta, use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_update(sparse_delta, use_locking, name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    with ops.device(self._device):
      return self._var._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)

  def _as_graph_element(self):
    return self._var._as_graph_element()  # pylint: disable=protected-access


def _tensor_conversion_packed_var_and_device(var,
                                             dtype=None,
                                             name=None,
                                             as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(
    PackedVarAndDevice, _tensor_conversion_packed_var_and_device)
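

# The sketch below is illustrative and not part of the original API: a minimal
# example, assuming `packed` is a PackedDistributedVariable like the one built
# in the sketch above. `on_device` returns a PackedVarAndDevice whose reads and
# updates are pinned to that device, and the conversion function registered
# above lets it be passed directly to ops such as `math_ops.add`.
def _example_var_on_device(packed):  # Illustrative sketch only.
  per_device = packed.on_device(packed.devices[0])
  per_device.assign_add(1.0)  # Updates the variable placed on that device.
  # Implicit conversion goes through _tensor_conversion_packed_var_and_device.
  return math_ops.add(per_device, 1.0)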