# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Standard functions for creating slots.

A slot is a `Variable` created with the same first m dimensions as a primary
variable or `Tensor`. A slot is always scoped in the namespace of the primary
object and typically has the same device and type.

Slots are typically used as accumulators to track values associated with
the primary object:

```python
# Optimizers can create a slot for each variable to track accumulators
accumulators = {var: create_zeros_slot(var, "momentum") for var in vs}
for var in vs:
  apply_momentum(var, accumulators[var], lr, grad, momentum_tensor)

# Slots can also be used for moving averages
mavg = create_slot(var, var.initialized_value(), "exponential_moving_avg")
update_mavg = mavg.assign_sub((mavg - var) * (1 - decay))
```
"""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables


def _create_slot_var(primary,
                     val,
                     scope,
                     validate_shape,
                     shape,
                     dtype,
                     *,
                     copy_xla_sharding=False):
  """Helper function for creating a slot variable."""

  # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
  # scope.
  current_partitioner = variable_scope.get_variable_scope().partitioner
  variable_scope.get_variable_scope().set_partitioner(None)
  # When initializing from `val` instead of a callable initializer, the shape
  # passed to `get_variable` is expected to be None, not <unknown> or any fully
  # defined shape.
  shape = shape if callable(val) else None
  if resource_variable_ops.is_resource_variable(primary):
    use_resource = True
  elif isinstance(primary, variables.RefVariable):
    use_resource = False
  else:
    use_resource = None
  slot = variable_scope.get_variable(
      scope,
      initializer=val,
      trainable=False,
      use_resource=use_resource,
      shape=shape,
      dtype=dtype,
      validate_shape=validate_shape)
  variable_scope.get_variable_scope().set_partitioner(current_partitioner)

  # pylint: disable=protected-access
  if isinstance(primary, variables.Variable) and primary._save_slice_info:
    # Primary is a partitioned variable, so we need to also indicate that
    # the slot is a partitioned variable. Slots have the same partitioning
    # as their primaries.
    # For example, when using AdamOptimizer in a linear model, slot.name
    # here can be "linear//weights/Adam:0", while primary.op.name is
    # "linear//weights". We want to get "Adam" as real_slot_name, so we
    # remove "'linear//weights' + '/'" and ":0".
    real_slot_name = slot.name[len(primary.op.name + "/"):-2]
    slice_info = primary._save_slice_info
    # Support a slot shape that differs from the primary's shape.
    # Example: if the primary's shape is [10, 20, 30], the slot's shape may be
    # None, [], [10], [10, 20], or [10, 20, 30]:
    #   - shape None or [10, 20, 30]: give the slot the same slice_info as the
    #     primary;
    #   - shape []: don't set the slot's slice_info;
    #   - shape [10] or [10, 20]: set the slot's slice_info according to its
    #     number of dimensions.
    n = slot.shape.ndims
    if n is None or n > 0:
      slot._set_save_slice_info(
          variables.Variable.SaveSliceInfo(
              slice_info.full_name + "/" + real_slot_name,
              slice_info.full_shape[:n], slice_info.var_offset[:n],
              slice_info.var_shape[:n]))
  # pylint: enable=protected-access

  # Copy XLA sharding attributes from the primary.
  if copy_xla_sharding:
    slot = xla_sharding.copy_sharding(primary, slot, use_sharding_op=False)
  return slot
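

# A minimal sketch (hypothetical names, not part of this module) of the
# slot-name extraction performed in `_create_slot_var` above:
#
#   primary_op_name = "linear//weights"        # primary.op.name
#   slot_name = "linear//weights/Adam:0"       # slot.name
#   real_slot_name = slot_name[len(primary_op_name + "/"):-2]
#   assert real_slot_name == "Adam"
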

def create_slot(primary,
                val,
                name,
                colocate_with_primary=True,
                *,
                copy_xla_sharding=False):
  """Create a slot initialized to the given value.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean. If True the slot is located
      on the same device as `primary`.
    copy_xla_sharding: Boolean. If True also copies XLA sharding
      from primary.

  Returns:
    A `Variable` object.
  """
  # Scope the slot name in the namespace of the primary variable.
  # Set the primary's name + '/' + name as the default name, so the
  # optimizer's scope name can be shared when reuse is True. Meanwhile, when
  # reuse is False and the same name has been used previously, the scope name
  # will get an '_N' suffix to make it unique.
  validate_shape = val.get_shape().is_fully_defined()
  if isinstance(primary, variables.Variable):
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if colocate_with_primary:
      distribution_strategy = distribution_strategy_context.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(primary):
        return _create_slot_var(
            primary,
            val,
            "",
            validate_shape,
            None,
            None,
            copy_xla_sharding=copy_xla_sharding)
    else:
      return _create_slot_var(
          primary,
          val,
          "",
          validate_shape,
          None,
          None,
          copy_xla_sharding=copy_xla_sharding)
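

# A minimal usage sketch for `create_slot` (hypothetical names, graph mode),
# mirroring the moving-average example in the module docstring:
#
#   var = variable_scope.get_variable("w", shape=[10])
#   mavg = create_slot(var, var.initialized_value(), "exponential_moving_avg")
#   update_mavg = mavg.assign_sub((mavg - var) * (1 - 0.999))
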

def create_slot_with_initializer(primary,
                                 initializer,
                                 shape,
                                 dtype,
                                 name,
                                 colocate_with_primary=True,
                                 *,
                                 copy_xla_sharding=False):
  """Creates a slot initialized using an `Initializer`.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    initializer: An `Initializer`. The initial value of the slot.
    shape: Shape of the initial value of the slot.
    dtype: Type of the value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean. If True the slot is located
      on the same device as `primary`.
    copy_xla_sharding: Boolean. If True also copies XLA sharding
      from primary.

  Returns:
    A `Variable` object.
  """
  # Scope the slot name in the namespace of the primary variable.
  # Set "primary.op.name + '/' + name" as the default name, so the optimizer's
  # scope name can be shared when reuse is True. Meanwhile, when reuse is
  # False and the same name has been used previously, the scope name will get
  # an '_N' suffix to make it unique.
  validate_shape = shape.is_fully_defined()
  if isinstance(primary, variables.Variable):
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if colocate_with_primary:
      distribution_strategy = distribution_strategy_context.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(primary):
        return _create_slot_var(
            primary,
            initializer,
            "",
            validate_shape,
            shape,
            dtype,
            copy_xla_sharding=copy_xla_sharding)
    else:
      return _create_slot_var(
          primary,
          initializer,
          "",
          validate_shape,
          shape,
          dtype,
          copy_xla_sharding=copy_xla_sharding)


def create_zeros_slot(primary,
                      name,
                      dtype=None,
                      colocate_with_primary=True,
                      *,
                      copy_xla_sharding=False):
  """Create a slot initialized to 0 with the same shape as the primary object.

  Args:
    primary: The primary `Variable` or `Tensor`.
    name: Name to use for the slot variable.
    dtype: Type of the slot variable. Defaults to the type of `primary`.
    colocate_with_primary: Boolean. If True the slot is located
      on the same device as `primary`.
    copy_xla_sharding: Boolean. If True also copies XLA sharding
      from primary.

  Returns:
    A `Variable` object.
  """
  if dtype is None:
    dtype = primary.dtype
  slot_shape = primary.get_shape()
  if slot_shape.is_fully_defined():
    initializer = init_ops.zeros_initializer()
    return create_slot_with_initializer(
        primary,
        initializer,
        slot_shape,
        dtype,
        name,
        colocate_with_primary=colocate_with_primary,
        copy_xla_sharding=copy_xla_sharding)
  else:
    if isinstance(primary, variables.Variable):
      slot_shape = array_ops.shape(primary.initialized_value())
    else:
      slot_shape = array_ops.shape(primary)
    val = array_ops.zeros(slot_shape, dtype=dtype)
    return create_slot(
        primary,
        val,
        name,
        colocate_with_primary=colocate_with_primary,
        copy_xla_sharding=copy_xla_sharding)
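

# A minimal usage sketch for `create_zeros_slot` (hypothetical variable):
#
#   v = variable_scope.get_variable("weights", shape=[10, 20])
#   accum = create_zeros_slot(v, "momentum")
#
# `accum` is a non-trainable variable with the same shape and dtype as `v`,
# created under the "weights/momentum" scope and, by default, colocated with
# `v`.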