# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Standard functions for creating slots.

A slot is a `Variable` created with the same shape as a primary variable or
`Tensor`. A slot is always scoped in the namespace of the primary object and
typically has the same device and type.

Slots are typically used as accumulators to track values associated with
the primary object:

```python
# Optimizers can create a slot for each variable to track accumulators
accumulators = {var : create_zeros_slot(var, "momentum") for var in vs}
for var in vs:
  apply_momentum(var, accumulators[var], lr, grad, momentum_tensor)

# Slots can also be used for moving averages
mavg = create_slot(var, var.initialized_value(), "exponential_moving_avg")
update_mavg = mavg.assign_sub((mavg - var) * (1 - decay))
```
"""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables


def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
  """Helper function for creating a slot variable.

  Args:
    primary: The primary `Variable` or `Tensor` the slot belongs to.
    val: Either a `Tensor` holding the initial value, or a callable
      initializer.
    scope: Name handed to `get_variable` as the variable name.
    validate_shape: Forwarded to `get_variable` as `validate_shape`.
    shape: Shape of the slot. Only used when `val` is callable.
    dtype: Forwarded to `get_variable` as `dtype`.

  Returns:
    The slot variable returned by `get_variable`.
  """
  # When init from val instead of callable initializer, the shape is expected
  # to be None, not <unknown> or any fully defined shape.
  shape = shape if callable(val) else None
  if resource_variable_ops.is_resource_variable(primary):
    use_resource = True
  elif isinstance(primary, variables.RefVariable):
    use_resource = False
  else:
    use_resource = None

  # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
  # scope.
  var_scope = variable_scope.get_variable_scope()
  current_partitioner = var_scope.partitioner
  var_scope.set_partitioner(None)
  try:
    slot = variable_scope.get_variable(
        scope,
        initializer=val,
        trainable=False,
        use_resource=use_resource,
        shape=shape,
        dtype=dtype,
        validate_shape=validate_shape)
  finally:
    # Restore the partitioner even when variable creation fails, so an error
    # here does not leave the enclosing scope with its partitioner removed.
    var_scope.set_partitioner(current_partitioner)

  # pylint: disable=protected-access
  if isinstance(primary, variables.Variable) and primary._save_slice_info:
    # Primary is a partitioned variable, so we need to also indicate that
    # the slot is a partitioned variable.  Slots have the same partitioning
    # as their primaries.
    # For example, when using AdamOptimizer in a linear model, slot.name
    # here can be "linear//weights/Adam:0", while primary.op.name is
    # "linear//weights". We want to get 'Adam' as real_slot_name, so we
    # remove "'linear//weights' + '/'" and ':0'.
    real_slot_name = slot.name[len(primary.op.name + "/"):-2]
    slice_info = primary._save_slice_info
    slot._set_save_slice_info(variables.Variable.SaveSliceInfo(
        slice_info.full_name + "/" + real_slot_name,
        slice_info.full_shape[:],
        slice_info.var_offset[:],
        slice_info.var_shape[:]))
  # pylint: enable=protected-access
  return slot


def _create_slot_in_primary_scope(primary, val, name, validate_shape, shape,
                                  dtype, colocate_with_primary):
  """Creates a slot variable inside a scope named after `primary`.

  Shared implementation of `create_slot` and `create_slot_with_initializer`.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` initial value or a callable initializer.
    name: Name to use for the slot variable.
    validate_shape: Forwarded to `_create_slot_var`.
    shape: Forwarded to `_create_slot_var`.
    dtype: Forwarded to `_create_slot_var`.
    colocate_with_primary: If true, the slot is created under the current
      distribution strategy's `colocate_vars_with(primary)` scope.

  Returns:
    A `Variable` object.
  """
  # Scope the slot name in the namespace of the primary variable.
  # Set "primary.op.name + '/' + name" as default name, so the scope name of
  # optimizer can be shared when reuse is True. Meanwhile when reuse is False
  # and the same name has been previously used, the scope name will add '_N'
  # as suffix for unique identifications.
  if context.executing_eagerly():
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if not colocate_with_primary:
      return _create_slot_var(primary, val, "", validate_shape, shape, dtype)
    distribution_strategy = distribution_strategy_context.get_strategy()
    with distribution_strategy.extended.colocate_vars_with(primary):
      return _create_slot_var(primary, val, "", validate_shape, shape, dtype)


def create_slot(primary, val, name, colocate_with_primary=True):
  """Create a slot initialized to the given value.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
  # Shape validation is only requested when the initial value's static shape
  # is fully known.
  validate_shape = val.get_shape().is_fully_defined()
  return _create_slot_in_primary_scope(
      primary, val, name, validate_shape, None, None, colocate_with_primary)


def create_slot_with_initializer(primary, initializer, shape, dtype, name,
                                 colocate_with_primary=True):
  """Creates a slot initialized using an `Initializer`.

  The type of the slot is given by `dtype`.

  Args:
    primary: The primary `Variable` or `Tensor`.
    initializer: An `Initializer`.  The initial value of the slot.
    shape: Shape of the initial value of the slot. Must provide
      `is_fully_defined()`.
    dtype: Type of the value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
  # Shape validation is only requested when `shape` is fully known.
  validate_shape = shape.is_fully_defined()
  return _create_slot_in_primary_scope(
      primary, initializer, name, validate_shape, shape, dtype,
      colocate_with_primary)


def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
  """Create a slot initialized to 0 with same shape as the primary object.

  Args:
    primary: The primary `Variable` or `Tensor`.
    name: Name to use for the slot variable.
    dtype: Type of the slot variable.  Defaults to the type of `primary`.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
  if dtype is None:
    dtype = primary.dtype
  slot_shape = primary.get_shape()
  if slot_shape.is_fully_defined():
    # Static shape is known: let a zeros initializer build the slot.
    initializer = init_ops.zeros_initializer(dtype)
    return create_slot_with_initializer(
        primary, initializer, slot_shape, dtype, name,
        colocate_with_primary=colocate_with_primary)

  # Static shape is not fully known: take the shape from a tensor at graph
  # execution time and build an explicit zeros tensor as the initial value.
  if isinstance(primary, variables.Variable):
    slot_shape = array_ops.shape(primary.initialized_value())
  else:
    slot_shape = array_ops.shape(primary)
  val = array_ops.zeros(slot_shape, dtype=dtype)
  return create_slot(primary, val, name,
                     colocate_with_primary=colocate_with_primary)