• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Standard functions for creating slots.
17
18A slot is a `Variable` created with the same shape as a primary variable or
19`Tensor`. A slot is always scoped in the namespace of the primary object and
20typically has the same device and type.
21
22Slots are typically used as accumulators to track values associated with
23the primary object:
24
25```python
26# Optimizers can create a slot for each variable to track accumulators
27accumulators = {var : create_zeros_slot(var, "momentum") for var in vs}
28for var in vs:
29  apply_momentum(var, accumulators[var], lr, grad, momentum_tensor)
30
31# Slots can also be used for moving averages
32mavg = create_slot(var, var.initialized_value(), "exponential_moving_avg")
33update_mavg = mavg.assign_sub((mavg - var) * (1 - decay))
34```
35"""
36# pylint: disable=g-bad-name
37
38from __future__ import absolute_import
39from __future__ import division
40from __future__ import print_function
41
42from tensorflow.python.distribute import distribution_strategy_context
43from tensorflow.python.eager import context
44from tensorflow.python.ops import array_ops
45from tensorflow.python.ops import init_ops
46from tensorflow.python.ops import resource_variable_ops
47from tensorflow.python.ops import variable_scope
48from tensorflow.python.ops import variables
49
50
def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
  """Helper function for creating a slot variable.

  The partitioner of the current variable scope is cleared while the slot is
  created and is always restored afterwards, even if variable creation raises.

  Args:
    primary: The primary `Variable` or `Tensor` the slot is created for.
    val: Initial value of the slot, or a callable initializer.
    scope: Name passed to `variable_scope.get_variable` for the slot.
    validate_shape: Forwarded to `variable_scope.get_variable`.
    shape: Shape of the slot. Only used when `val` is callable; otherwise it
      is replaced by None.
    dtype: Forwarded to `variable_scope.get_variable`.

  Returns:
    The slot variable returned by `variable_scope.get_variable`.
  """
  # When init from val instead of callable initializer, the shape is expected to
  # be None, not <unknown> or any fully defined shape.
  shape = shape if callable(val) else None

  # Make the slot the same kind of variable as the primary when that can be
  # determined; None leaves the choice to `get_variable`.
  if resource_variable_ops.is_resource_variable(primary):
    use_resource = True
  elif isinstance(primary, variables.RefVariable):
    use_resource = False
  else:
    use_resource = None

  # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
  # scope.
  current_partitioner = variable_scope.get_variable_scope().partitioner
  variable_scope.get_variable_scope().set_partitioner(None)
  try:
    slot = variable_scope.get_variable(
        scope,
        initializer=val,
        trainable=False,
        use_resource=use_resource,
        shape=shape,
        dtype=dtype,
        validate_shape=validate_shape)
  finally:
    # Restore the partitioner even on failure so that an error while creating
    # the slot does not leave the enclosing scope without its partitioner.
    variable_scope.get_variable_scope().set_partitioner(current_partitioner)

  # pylint: disable=protected-access
  if isinstance(primary, variables.Variable) and primary._save_slice_info:
    # Primary is a partitioned variable, so we need to also indicate that
    # the slot is a partitioned variable.  Slots have the same partitioning
    # as their primaries.
    # For example, when using AdamOptimizer in a linear model, slot.name
    # here can be "linear//weights/Adam:0", while primary.op.name is
    # "linear//weights". We want to get 'Adam' as real_slot_name, so we
    # remove "'linear//weights' + '/'" and ':0'.
    real_slot_name = slot.name[len(primary.op.name + "/"):-2]
    slice_info = primary._save_slice_info
    slot._set_save_slice_info(variables.Variable.SaveSliceInfo(
        slice_info.full_name + "/" + real_slot_name,
        slice_info.full_shape[:],
        slice_info.var_offset[:],
        slice_info.var_shape[:]))
  # pylint: enable=protected-access
  return slot
95
96
def create_slot(primary, val, name, colocate_with_primary=True):
  """Create a slot initialized to the given value.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
  # Shape validation is only requested when the static shape of `val` is
  # completely known.
  shape_is_static = val.get_shape().is_fully_defined()

  # The slot lives in the namespace of the primary: "<primary name>/<name>" is
  # used as the default scope name, so the optimizer's scope name can be shared
  # when reuse is True. When reuse is False and the name is already taken, the
  # scope name gets an '_N' suffix to keep it unique.
  # pylint: disable=protected-access
  prefix = (primary._shared_name if context.executing_eagerly()
            else primary.op.name)
  # pylint: enable=protected-access

  with variable_scope.variable_scope(None, prefix + "/" + name):
    if not colocate_with_primary:
      return _create_slot_var(primary, val, "", shape_is_static, None, None)
    strategy = distribution_strategy_context.get_strategy()
    with strategy.extended.colocate_vars_with(primary):
      return _create_slot_var(primary, val, "", shape_is_static, None, None)
129
130
def create_slot_with_initializer(primary, initializer, shape, dtype, name,
                                 colocate_with_primary=True):
  """Creates a slot initialized using an `Initializer`.

  The shape and type of the slot are given explicitly by `shape` and `dtype`.

  Args:
    primary: The primary `Variable` or `Tensor`.
    initializer: An `Initializer`.  The initial value of the slot.
    shape: Shape of the initial value of the slot.  Must provide
      `is_fully_defined()`.
    dtype: Type of the value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
  # Shape validation is only requested when `shape` is completely known.
  shape_is_static = shape.is_fully_defined()

  # The slot lives in the namespace of the primary: "<primary name>/<name>" is
  # used as the default scope name, so the optimizer's scope name can be shared
  # when reuse is True. When reuse is False and the name is already taken, the
  # scope name gets an '_N' suffix to keep it unique.
  # pylint: disable=protected-access
  prefix = (primary._shared_name if context.executing_eagerly()
            else primary.op.name)
  # pylint: enable=protected-access

  with variable_scope.variable_scope(None, prefix + "/" + name):
    if not colocate_with_primary:
      return _create_slot_var(primary, initializer, "", shape_is_static,
                              shape, dtype)
    strategy = distribution_strategy_context.get_strategy()
    with strategy.extended.colocate_vars_with(primary):
      return _create_slot_var(primary, initializer, "", shape_is_static,
                              shape, dtype)
168
169
def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
  """Create a zero-filled slot shaped like the primary object.

  Args:
    primary: The primary `Variable` or `Tensor`.
    name: Name to use for the slot variable.
    dtype: Type of the slot variable.  Defaults to the type of `primary`.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
  slot_dtype = primary.dtype if dtype is None else dtype
  static_shape = primary.get_shape()

  if static_shape.is_fully_defined():
    # The shape is known statically, so a zeros initializer can be used.
    zeros_init = init_ops.zeros_initializer(slot_dtype)
    return create_slot_with_initializer(
        primary, zeros_init, static_shape, slot_dtype, name,
        colocate_with_primary=colocate_with_primary)

  # The static shape is incomplete: take the shape from a tensor instead. For
  # a variable this is its initialized value, otherwise the primary itself.
  if isinstance(primary, variables.Variable):
    shape_source = primary.initialized_value()
  else:
    shape_source = primary
  zeros = array_ops.zeros(array_ops.shape(shape_source), dtype=slot_dtype)
  return create_slot(primary, zeros, name,
                     colocate_with_primary=colocate_with_primary)
199