# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Optional helper for gradient handling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu.ops import tpu_ops


def get_gradients_through_compute_gradients(optimizer, loss, activations):
  """Compute gradients to send to TPU embedding.

  Args:
    optimizer: a subclass of optimizer.Optimizer, usually CrossShardOptimizer.
      Used to call compute_gradients().
    loss: a Tensor to call optimizer.compute_gradients() on.
    activations: an OrderedDict mapping feature_name to Tensors of activations.

  Returns:
    An OrderedDict mapping from feature name Strings to Tensors of gradients of
      the loss wrt the activations of the features.
  """
  activation_list = activations.values()
  grads_and_vars = optimizer.compute_gradients(loss, activation_list)
  grads = [grad for grad, _ in grads_and_vars]
  feature_to_gradient_dict = collections.OrderedDict(
      zip(activations.keys(), grads))
  return feature_to_gradient_dict


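# A minimal usage sketch for get_gradients_through_compute_gradients() above
# (not used by this module). `cross_shard_optimizer`, `loss` and the feature
# activation tensors are hypothetical placeholders for the model's own
# objects; only the OrderedDict structure of `activations` is assumed:
#
#   activations = collections.OrderedDict([
#       ('feature_a', feature_a_activations),
#       ('feature_b', feature_b_activations),
#   ])
#   grads = get_gradients_through_compute_gradients(
#       cross_shard_optimizer, loss, activations)
#   # grads['feature_a'] holds d(loss)/d(feature_a_activations); keys keep the
#   # same order as `activations`.

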
def create_dummy_table_variables(tpu_embedding):
  """Create dummy embedding table variables.

  The sole purpose of these dummy variables is to trigger gradient
  calculation wrt them so that the gradients wrt the activations can be
  captured and later sent to TPU embedding.

  Args:
    tpu_embedding: TPUEmbedding, dummy table variables will be created for use
      with tpu_embedding.

  Returns:
    A tuple of dummy variables and their initializer.

  Raises:
    RuntimeError: if the collection used to store gradients already exists and
      is not empty.
  """
  dummy_table_variables = collections.OrderedDict()
  for table_id, table in enumerate(tpu_embedding.table_to_features_dict):
    dummy_table_variables[table] = (
        # Explicitly specifying collections prevents this variable from
        # being added to the GLOBAL_VARIABLES collection, so that Saver()
        # ignores it.
        # But TensorFlow optimizers create slot variables for these dummy
        # variables, e.g. tpu_embedding_dummy_table_variable_mlp_user/Adam{_1},
        # which will be in the GLOBAL_VARIABLES collection.
        variable_scope.get_variable(
            'tpu_embedding_dummy_table_variable_{}'.format(table),
            dtype=dtypes.float32,
            shape=[1],
            use_resource=True,
            trainable=True,
            collections=['tpu_embedding_dummy_table_variables']))

    g = ops.get_default_graph()
    table_gradients = g.get_collection_ref(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    if table_gradients:
      raise RuntimeError(
          'tpu_embedding_gradients_table_{} is not empty.'.format(table_id))
    num_features = len(tpu_embedding.table_to_features_dict[table])
    table_gradients.extend([None for _ in range(num_features)])

  return (dummy_table_variables,
          variables.variables_initializer(
              dummy_table_variables.values(),
              name='tpu_embedding_dummy_table_variables_init'))


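# Sketch of how the return values of create_dummy_table_variables() are
# typically consumed (hypothetical names; `embedding` stands for a configured
# TPUEmbedding object). Because the dummy variables are kept out of the
# GLOBAL_VARIABLES collection, the default global-variables initializer will
# not touch them, so the returned initializer has to be run explicitly:
#
#   dummy_tables, dummy_tables_init = create_dummy_table_variables(embedding)
#   ...
#   session.run(dummy_tables_init)

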
def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
                                              dummy_table_variables):
  """Have activations depend on dummy table variables for gradient intercept.

  Args:
    tpu_embedding: TPUEmbedding, activations and dummy_table_variables are from
      tpu_embedding.
    activations: An OrderedDict of feature name String to activation tensors.
    dummy_table_variables: An OrderedDict of table name String to dummy table
      variables.

  Returns:
    An OrderedDict of feature name String to activation tensors, which can be
      used just as the activations input.
  """
  new_activations = collections.OrderedDict()
  for feature in activations:
    table = tpu_embedding.feature_to_config_dict[feature].table_id
    new_activations[feature] = tpu_ops.tpu_embedding_activations(
        dummy_table_variables[table],
        activations[feature],
        table_id=list(tpu_embedding.table_to_config_dict).index(table),
        lookup_id=tpu_embedding.table_to_features_dict[table].index(feature))
  return new_activations


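# Sketch of the intended wiring of the helpers above (hypothetical names;
# `embedding` is a configured TPUEmbedding object, and `embedding
# .get_activations()`, `model_fn` and `optimizer` are assumptions standing in
# for the model's own code). The per-table 'tpu_embedding_gradients_table_*'
# collections are expected to be filled in by the gradient function registered
# for the tpu_embedding_activations op when the backward pass is built:
#
#   dummy_tables, dummy_tables_init = create_dummy_table_variables(embedding)
#   activations = hook_dummy_table_variables_to_activations(
#       embedding, embedding.get_activations(), dummy_tables)
#   loss = model_fn(activations)
#   train_op = optimizer.minimize(loss)  # builds the backward pass

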
def get_gradients_through_dummy_table_variables(tpu_embedding):
  """Get gradients wrt the activations of each feature.

  Gradients that are not defined (e.g. because the corresponding activations
  were never used) are replaced with zeros, and a warning is logged.

  Args:
    tpu_embedding: TPUEmbedding, whose dummy table variables were used to
      capture the gradients wrt the activations.

  Returns:
    An OrderedDict mapping feature name to gradient.
  """
  g = ops.get_default_graph()
  gradients_found = False
  for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
    table_gradients = g.get_collection(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    if any(gradient is None for gradient in table_gradients):
      # TODO(bfontain): create a white-list for optimizers which are compatible
      # with `tf.stop_gradient`.
      logging.warn(
          'Table {} with id {} has undefined gradients: this is probably '
          'because the model asked TPUEmbedding to compute activations that '
          'were not used, or tf.stop_gradient() was applied. Gradients of '
          'zeros are sent back to TPUEmbedding instead. Gradients of zeros '
          'and no gradients are equivalent for SGD, AdaGrad, FTRL, etc., but '
          'might differ for other optimizers due to the implementation of '
          'the TPU embedding optimizers.'.format(table, table_id))
    gradients_found = gradients_found or any(
        gradient is not None for gradient in table_gradients)

  if not gradients_found:
    logging.warn(
        'All tables have undefined gradients: this is probably because the '
        'model asked TPUEmbedding to compute activations that were not used. '
        'If all TPUEmbedding features have stop_gradient applied, consider '
        'using the INFERENCE mode instead.')

  feature_to_gradient_dict = collections.OrderedDict()
  for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
    table_gradients = g.get_collection(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    for feature, gradient in zip(tpu_embedding.table_to_features_dict[table],
                                 table_gradients):
      if gradient is not None:
        feature_to_gradient_dict[feature] = gradient
      else:
        dimension = tpu_embedding.table_to_config_dict[table].dimension
        batch_size = tpu_embedding.batch_size_per_core
        max_sequence_length = (
            tpu_embedding.feature_to_config_dict[feature].max_sequence_length)
        if max_sequence_length:
          feature_to_gradient_dict[feature] = array_ops.zeros(
              [batch_size, max_sequence_length, dimension])
        else:
          feature_to_gradient_dict[feature] = array_ops.zeros(
              [batch_size, dimension])

  return feature_to_gradient_dict

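
# End-to-end sketch for reading the captured gradients back and sending them
# to the embedding (hypothetical names, continuing the sketch above
# hook_dummy_table_variables_to_activations; `embedding
# .generate_send_gradients_op` is an assumption about the TPUEmbedding API,
# not something defined in this module):
#
#   feature_to_grads = get_gradients_through_dummy_table_variables(embedding)
#   send_grads_op = embedding.generate_send_gradients_op(feature_to_grads)
#   # Group send_grads_op with the model's train_op so both run each step.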