# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Optional helper for gradient handling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu.ops import tpu_ops


def get_gradients_through_compute_gradients(optimizer, loss, activations):
  """Compute gradients to send to TPU embedding.

  Args:
    optimizer: a subclass of optimizer.Optimizer, usually CrossShardOptimizer.
      Used to call compute_gradients().
    loss: a Tensor to call optimizer.compute_gradients() on.
    activations: an OrderedDict mapping feature_name to Tensors of activations.

  Returns:
    An OrderedDict mapping from feature name Strings to Tensors of gradients of
    the loss wrt the activations of the features.
  """
  activation_list = activations.values()
  grads_and_vars = optimizer.compute_gradients(loss, activation_list)
  grads = [grad for grad, _ in grads_and_vars]
  feature_to_gradient_dict = collections.OrderedDict(
      zip(activations.keys(), grads))
  return feature_to_gradient_dict
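

# An illustrative usage sketch of the helper above, kept in comment form.
# `loss` and `embedding_activations` are placeholder names: the latter stands
# for the OrderedDict of feature name to activation Tensor produced by the
# TPUEmbedding object.
#
#   optimizer = tf.compat.v1.tpu.CrossShardOptimizer(
#       tf.compat.v1.train.AdagradOptimizer(learning_rate=0.1))
#   feature_to_gradient_dict = get_gradients_through_compute_gradients(
#       optimizer, loss, embedding_activations)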


def create_dummy_table_variables(tpu_embedding):
  """Create dummy embedding table variables.

  The sole purpose of these dummy variables is to trigger gradient
  calculation wrt them so that the gradients wrt activations can be captured
  and later sent to TPU embedding.

  Args:
    tpu_embedding: TPUEmbedding, dummy table variables will be created for use
      with tpu_embedding.

  Returns:
    A tuple of dummy variables and their initializer.

  Raises:
    RuntimeError: if the collection to store gradients already exists and is
      not empty.
  """
  dummy_table_variables = collections.OrderedDict()
  for table_id, table in enumerate(tpu_embedding.table_to_features_dict):
    dummy_table_variables[table] = (
        # Explicitly specifying collections prevents this variable from
        # being added to the GLOBAL_VARIABLES collection, so that Saver()
        # ignores it.
        # But TensorFlow optimizers create slot variables for these dummy
        # variables, e.g.
        # tpu_embedding_dummy_table_variable_mlp_user/Adam{_1},
        # which will be in the GLOBAL_VARIABLES collection.
        variable_scope.get_variable(
            'tpu_embedding_dummy_table_variable_{}'.format(table),
            dtype=dtypes.float32,
            shape=[1],
            use_resource=True,
            trainable=True,
            collections=['tpu_embedding_dummy_table_variables']))

    g = ops.get_default_graph()
    table_gradients = g.get_collection_ref(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    if table_gradients:
      raise RuntimeError(
          'tpu_embedding_gradients_table_{} is not empty.'.format(table_id))
    num_features = len(tpu_embedding.table_to_features_dict[table])
    table_gradients.extend([None for _ in range(num_features)])

  return (dummy_table_variables,
          variables.variables_initializer(
              dummy_table_variables.values(),
              name='tpu_embedding_dummy_table_variables_init'))


def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
                                              dummy_table_variables):
  """Have activations depend on dummy table variables for gradient intercept.

  Args:
    tpu_embedding: TPUEmbedding, activations and dummy_table_variables are from
      tpu_embedding.
    activations: An OrderedDict of feature name String to activation tensors.
    dummy_table_variables: An OrderedDict of table name String to dummy table
      variables.

  Returns:
    An OrderedDict of feature name String to activation tensors, which can be
    used just as the activations input.
  """
  new_activations = collections.OrderedDict()
  for feature in activations:
    table = tpu_embedding.feature_to_config_dict[feature].table_id
    new_activations[feature] = tpu_ops.tpu_embedding_activations(
        dummy_table_variables[table],
        activations[feature],
        table_id=list(tpu_embedding.table_to_config_dict).index(table),
        lookup_id=tpu_embedding.table_to_features_dict[table].index(feature))
  return new_activations
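

# An illustrative sketch of how the two helpers above are typically combined
# before building the rest of the model. `tpu_embedding` is assumed to be a
# configured TPUEmbedding object and `activations` the OrderedDict of its
# embedding activations; both names are placeholders.
#
#   dummy_tables, dummy_init = create_dummy_table_variables(tpu_embedding)
#   activations = hook_dummy_table_variables_to_activations(
#       tpu_embedding, activations, dummy_tables)
#
# The model body then consumes the hooked `activations`; taking gradients of
# the loss also produces gradients wrt the dummy tables, which end up in the
# 'tpu_embedding_gradients_table_{}' graph collections read back below.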


def get_gradients_through_dummy_table_variables(tpu_embedding):
  """Get gradients wrt the activations of each feature.

  Args:
    tpu_embedding: TPUEmbedding, the object whose feature activation gradients
      (captured via the dummy table variables) are retrieved.

  Returns:
    An OrderedDict mapping feature name to gradient.
  """
  g = ops.get_default_graph()
  gradients_found = False
  for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
    table_gradients = g.get_collection(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    if any(gradient is None for gradient in table_gradients):
      # TODO(bfontain): create a white-list for optimizers which are compatible
      # with `tf.stop_gradient`.
      logging.warn(
          'Table {} with id {} has undefined gradients: this is probably '
          'because the model asked TPUEmbedding to compute activations that '
          'were not used, or tf.stop_gradient() was applied. Gradients of '
          'zeros are sent back to TPUEmbedding instead. Gradients of zeros '
          'and no gradients are equivalent for SGD, AdaGrad, FTRL, etc., but '
          'might differ for other optimizers due to the implementation of '
          'the TPU embedding optimizers.'.format(table, table_id))
    gradients_found = gradients_found or any(
        gradient is not None for gradient in table_gradients)

  if not gradients_found:
    logging.warn(
        'All tables have undefined gradients: this is probably because the '
        'model asked TPUEmbedding to compute activations that were not used. '
        'If all TPUEmbedding features have stop_gradient applied, consider '
        'using the INFERENCE mode instead.')

  feature_to_gradient_dict = collections.OrderedDict()
  for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
    table_gradients = g.get_collection(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    for feature, gradient in zip(tpu_embedding.table_to_features_dict[table],
                                 table_gradients):
      if gradient is not None:
        feature_to_gradient_dict[feature] = gradient
      else:
        dimension = tpu_embedding.table_to_config_dict[table].dimension
        batch_size = tpu_embedding.batch_size_per_core
        max_sequence_length = (
            tpu_embedding.feature_to_config_dict[feature].max_sequence_length)
        if max_sequence_length:
          feature_to_gradient_dict[feature] = array_ops.zeros(
              [batch_size, max_sequence_length, dimension])
        else:
          feature_to_gradient_dict[feature] = array_ops.zeros(
              [batch_size, dimension])

  return feature_to_gradient_dict
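

# An illustrative end-to-end sketch of the dummy-table path on the training
# step. `train_op` is assumed to have been built from the hooked activations
# (so the gradient collections above are populated), and
# generate_send_gradients_op() is the send-side TPUEmbedding API assumed here;
# the exact signature depends on the TPUEmbedding version in use.
#
#   with ops.control_dependencies([train_op]):
#     gradients = get_gradients_through_dummy_table_variables(tpu_embedding)
#     send_gradients_op = tpu_embedding.generate_send_gradients_op(gradients)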