# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Optional helper for gradient handling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.tpu.ops import tpu_ops


def get_gradients_through_compute_gradients(optimizer, loss, activations):
  """Compute gradients to send to TPU embedding.

  Args:
    optimizer: a subclass of optimizer.Optimizer, usually CrossShardOptimizer.
      Used to call compute_gradients().
    loss: a Tensor to call optimizer.compute_gradients() on.
    activations: an OrderedDict mapping feature_name to Tensors of activations.

  Returns:
    An OrderedDict mapping from feature name Strings to Tensors of gradients of
    the loss wrt the activations of the features.
  """
  activation_list = list(activations.values())
  grads_and_vars = optimizer.compute_gradients(loss, activation_list)
  grads = [grad for grad, _ in grads_and_vars]
  feature_to_gradient_dict = collections.OrderedDict(
      zip(activations.keys(), grads))
  return feature_to_gradient_dict
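

# A minimal usage sketch for the helper above (illustrative only: `embedding`
# is assumed to be a tpu_embedding.TPUEmbedding whose activations came from
# get_activations(), and `cross_shard_opt` and `loss` are placeholder names;
# sending the gradients back via generate_send_gradients_op is one plausible
# follow-up, not prescribed by this module):
#
#   activations = embedding.get_activations()
#   # ... build the model on top of `activations` and derive a scalar `loss`
#   grads = get_gradients_through_compute_gradients(
#       cross_shard_opt, loss, activations)
#   send_op = embedding.generate_send_gradients_op(grads)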


def create_dummy_table_variables(tpu_embedding):
  """Create dummy embedding table variables.

  The sole purpose of these dummy variables is to trigger gradient
  calculation wrt them so that the gradients wrt activation can be captured
  and later sent to TPU embedding.

  Args:
    tpu_embedding: TPUEmbedding, dummy table variables will be created for use
      with tpu_embedding.

  Returns:
    A tuple of dummy variables and their initializer.

  Raises:
    RuntimeError: if collection to store gradients already exists and is not
      empty.
  """
  dummy_table_variables = collections.OrderedDict()
  for table_id, table in enumerate(tpu_embedding.table_to_features_dict):
    dummy_table_variables[table] = (
        # Explicitly specifying collections prevents this variable from
        # being added to the GLOBAL_VARIABLES collection, so that Saver()
        # ignores it.
        # But TensorFlow optimizers create slot variables for these dummy
        # variables, e.g. tpu_embedding_dummy_table_variable_mlp_user/Adam{_1},
        # which will be in the GLOBAL_VARIABLES collection.
        variable_scope.get_variable(
            'tpu_embedding_dummy_table_variable_{}'.format(table),
            dtype=dtypes.float32,
            shape=[1],
            use_resource=True,
            trainable=True,
            collections=['tpu_embedding_dummy_table_variables']))

    g = ops.get_default_graph()
    table_gradients = g.get_collection_ref(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    if table_gradients:
      raise RuntimeError(
          'tpu_embedding_gradients_table_{} is not empty.'.format(table_id))
    table_gradients.extend(
        [None] * len(tpu_embedding.table_to_features_dict[table]))

  return (dummy_table_variables,
          variables.variables_initializer(
              list(dummy_table_variables.values()),
              name='tpu_embedding_dummy_table_variables_init'))


def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
                                              dummy_table_variables):
  """Have activations depend on dummy table variables for gradient intercept.

  Args:
    tpu_embedding: TPUEmbedding, activations and dummy_table_variables are from
      tpu_embedding.
    activations: An OrderedDict of feature name String to activation tensors.
    dummy_table_variables: An OrderedDict of table name String to dummy table
      variables.

  Returns:
    An OrderedDict of feature name String to activation tensors, which can be
    used just as the activations input.
  """
  new_activations = collections.OrderedDict()
  for feature in activations:
    table = tpu_embedding.feature_to_table_dict[feature]
    new_activations[feature] = tpu_ops.tpu_embedding_activations(
        dummy_table_variables[table],
        activations[feature],
        table_id=list(tpu_embedding.table_to_config_dict).index(table),
        lookup_id=tpu_embedding.table_to_features_dict[table].index(feature))
  return new_activations


def get_gradients_through_dummy_table_variables(tpu_embedding):
  """Get gradients wrt the activations of each feature.

  Args:
    tpu_embedding: TPUEmbedding, create dummy table variable to be used with
      tpu_embedding.

  Returns:
    An OrderedDict mapping feature name to gradient.

  Raises:
    ValueError: if some gradients are not defined.
  """
  g = ops.get_default_graph()
  feature_to_gradient_dict = collections.OrderedDict()
  for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
    table_gradients = g.get_collection(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    if any(gradient is None for gradient in table_gradients):
      raise ValueError(
          'Table {} with id {} has undefined gradients: this is probably '
          'because the model asked TPUEmbedding to compute activations that '
          'were not used.'.format(table, table_id))
    for feature, gradient in zip(tpu_embedding.table_to_features_dict[table],
                                 table_gradients):
      feature_to_gradient_dict[feature] = gradient
  return feature_to_gradient_dict
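

# A minimal end-to-end sketch of the dummy-table path (illustrative only:
# `embedding` is assumed to be a tpu_embedding.TPUEmbedding and `optimizer` a
# CrossShardOptimizer-style optimizer; generate_send_gradients_op is one
# plausible way to ship the result back, not mandated by these helpers):
#
#   dummy_tables, dummy_init = create_dummy_table_variables(embedding)
#   activations = hook_dummy_table_variables_to_activations(
#       embedding, embedding.get_activations(), dummy_tables)
#   # ... build the model on `activations` and derive a scalar `loss` ...
#   train_op = optimizer.minimize(loss)  # backprop through the
#   # tpu_embedding_activations ops populates the per-table
#   # 'tpu_embedding_gradients_table_*' graph collections
#   grads = get_gradients_through_dummy_table_variables(embedding)
#   send_op = embedding.generate_send_gradients_op(grads)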