# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optional helper for gradient handling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.tpu.ops import tpu_ops


def get_gradients_through_compute_gradients(optimizer, loss, activations):
  """Compute gradients to send to TPU embedding.

  Args:
    optimizer: a subclass of optimizer.Optimizer, usually CrossShardOptimizer.
      Used to call compute_gradients().
    loss: a Tensor to call optimizer.compute_gradients() on.
    activations: an OrderedDict mapping feature_name to Tensors of activations.

  Returns:
    An OrderedDict mapping from feature name strings to Tensors of gradients of
      the loss wrt the activations of the features.
  """
  # Materialize a list so this works under Python 3, where dict.values()
  # returns a view rather than a list.
  activation_list = list(activations.values())
  grads_and_vars = optimizer.compute_gradients(loss, activation_list)
  grads = [grad for grad, _ in grads_and_vars]
  feature_to_gradient_dict = collections.OrderedDict(
      zip(activations.keys(), grads))
  return feature_to_gradient_dict
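
# Illustrative usage sketch (not part of this module). Assumes a pre-built
# `optimizer` (e.g. a CrossShardOptimizer), a scalar `loss`, and activation
# Tensors obtained from TPUEmbedding; all names below are hypothetical:
#
#   activations = collections.OrderedDict(
#       [('video_id', video_act), ('query', query_act)])
#   grads = get_gradients_through_compute_gradients(
#       optimizer, loss, activations)
#   # grads['video_id'] is d(loss)/d(video_act); keys match `activations`.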


def create_dummy_table_variables(tpu_embedding):
  """Create dummy embedding table variables.

  The sole purpose of these dummy variables is to trigger gradient
  calculation wrt them so that the gradients wrt the activations can be
  captured and later sent to TPU embedding.

  Args:
    tpu_embedding: TPUEmbedding, dummy table variables will be created for use
      with tpu_embedding.

  Returns:
    A tuple of dummy variables and their initializer.

  Raises:
    RuntimeError: if the collection used to store the gradients already exists
      and is not empty.
  """
  dummy_table_variables = collections.OrderedDict()
  for table_id, table in enumerate(tpu_embedding.table_to_features_dict):
    dummy_table_variables[table] = (
        # Explicitly specifying collections prevents this variable from
        # being added to the GLOBAL_VARIABLES collection, so that Saver()
        # ignores it.
        # Note, however, that TensorFlow optimizers still create slot
        # variables for these dummy variables, e.g.
        # tpu_embedding_dummy_table_variable_mlp_user/Adam{_1}, and those
        # slots do end up in the GLOBAL_VARIABLES collection.
        variable_scope.get_variable(
            'tpu_embedding_dummy_table_variable_{}'.format(table),
            dtype=dtypes.float32,
            shape=[1],
            use_resource=True,
            trainable=True,
            collections=['tpu_embedding_dummy_table_variables']))

    g = ops.get_default_graph()
    table_gradients = g.get_collection_ref(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    if table_gradients:
      raise RuntimeError(
          'tpu_embedding_gradients_table_{} is not empty.'.format(table_id))
    # Reserve one gradient slot per feature of this table; the slots are
    # filled in later by the gradient of the tpu_embedding_activations op.
    table_gradients.extend(
        [None] * len(tpu_embedding.table_to_features_dict[table]))

  return (dummy_table_variables,
          variables.variables_initializer(
              list(dummy_table_variables.values()),
              name='tpu_embedding_dummy_table_variables_init'))
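
# Illustrative flow (not part of this module; `embedding` is a hypothetical
# TPUEmbedding instance). The dummies are ordinary variables and need to be
# initialized even though Saver() skips them:
#
#   dummy_vars, dummy_vars_init = create_dummy_table_variables(embedding)
#   session.run(dummy_vars_init)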


def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
                                              dummy_table_variables):
  """Have activations depend on dummy table variables for gradient intercept.

  Args:
    tpu_embedding: TPUEmbedding, activations and dummy_table_variables are from
      tpu_embedding.
    activations: An OrderedDict of feature name string to activation tensors.
    dummy_table_variables: An OrderedDict of table name string to dummy table
      variables.

  Returns:
    An OrderedDict of feature name string to activation tensors, which can be
      used just as the activations input.
  """
  new_activations = collections.OrderedDict()
  for feature in activations:
    table = tpu_embedding.feature_to_table_dict[feature]
    new_activations[feature] = tpu_ops.tpu_embedding_activations(
        dummy_table_variables[table],
        activations[feature],
        # dict.keys() has no index() in Python 3, so materialize a list first.
        table_id=list(tpu_embedding.table_to_config_dict).index(table),
        lookup_id=tpu_embedding.table_to_features_dict[table].index(feature))
  return new_activations
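
# Illustrative sketch (hypothetical names): the hooked activations are drop-in
# replacements for the originals, so downstream model code is unchanged while
# gradients now flow through the dummy variables:
#
#   hooked = hook_dummy_table_variables_to_activations(
#       embedding, activations, dummy_vars)
#   logits = model_fn(hooked)  # use `hooked` exactly like `activations`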


def get_gradients_through_dummy_table_variables(tpu_embedding):
  """Get gradients wrt the activations of each feature.

  Args:
    tpu_embedding: TPUEmbedding, the instance for which the dummy table
      variables were created.

  Returns:
    An OrderedDict mapping feature name to gradient.

  Raises:
    ValueError: if some gradients are not defined.
  """
  g = ops.get_default_graph()
  feature_to_gradient_dict = collections.OrderedDict()
  for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
    table_gradients = g.get_collection(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    if any(gradient is None for gradient in table_gradients):
      raise ValueError(
          'Table {} with id {} has undefined gradients: this is probably '
          'because the model asked TPUEmbedding to compute activations that '
          'were not used.'.format(table, table_id))
    for feature, gradient in zip(tpu_embedding.table_to_features_dict[table],
                                 table_gradients):
      feature_to_gradient_dict[feature] = gradient
  return feature_to_gradient_dict
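
# End-to-end sketch of the gradient-intercept path (illustrative, with
# hypothetical names; assumes the TF1-style TPUEmbedding API with
# get_activations() and generate_send_gradients_op()):
#
#   dummy_vars, dummy_vars_init = create_dummy_table_variables(embedding)
#   activations = hook_dummy_table_variables_to_activations(
#       embedding, embedding.get_activations(), dummy_vars)
#   loss = model_fn(activations)
#   train_op = optimizer.minimize(loss)  # computing gradients fills the
#                                        # per-table gradient collections
#   with ops.control_dependencies([train_op]):
#     grads = get_gradients_through_dummy_table_variables(embedding)
#     send_op = embedding.generate_send_gradients_op(grads)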