# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for sampling_ops.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.nn.python.ops import sampling_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn
from tensorflow.python.platform import test


class RankSampledSoftmaxLossTest(test.TestCase):

  def setUp(self):
    self._sampled = [3, 4, 5, 6, 7]
    self._num_sampled = len(self._sampled)
    # Because the values of all matrices increase with their indices, logits
    # increase with class id. So, for the sampled classes above, adaptive
    # sampling will select these resampled classes.
    self._resampled = [5, 6, 7]
    self._num_resampled = len(self._resampled)
    self._num_classes = 10
    self._num_true = 2
    self._sampled_values = (self._sampled, [[0.5], [0.5]],
                            [0.5, 0.5, 0.5, 0.5, 0.5])
    self._resampled_values = (self._resampled, [[0.5], [0.5]], [0.5, 0.5, 0.5])
    self._remove_accidental_hits = False
    self._embed_dim = 5
    self._batch_size = 2
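  # Illustrative sketch (a hypothetical helper, not invoked by any test): a
  # small numpy check of the claim above that logits increase with class id,
  # so adaptive resampling of the five sampled classes keeps the three
  # highest ids, [5, 6, 7]. Assumes numpy is importable in the test
  # environment; the values mirror the _weights/_biases/_inputs fixtures
  # defined below.
  def _sketchAdaptiveResampling(self):
    import numpy as np  # local import: this helper is illustration only
    weights = np.arange(10.)[:, None] + np.arange(5) * 0.1  # row i = class i
    biases = np.arange(10.) * 0.1
    inputs = np.array([[0., 1., 2., 3., 4.],
                       [10., 11., 12., 13., 14.]])
    # Logits work out to 10.1*i + 3.0 on the first example and 60.1*i + 13.0
    # on the second, both strictly increasing in class id i.
    logits = inputs.dot(weights.T) + biases
    sampled = np.array(self._sampled)
    # The op ranks sampled classes by summed exp(logit / temperature); with
    # logits monotone in class id, ranking by raw summed logits keeps the
    # same classes (and avoids overflow in this sketch).
    order = np.argsort(logits[:, sampled].sum(axis=0))
    top = sampled[order[-self._num_resampled:]]
    assert sorted(top.tolist()) == self._resampled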
  def _weights(self):
    return constant_op.constant([
        [0.0, 0.1, 0.2, 0.3, 0.4],
        [1.0, 1.1, 1.2, 1.3, 1.4],
        [2.0, 2.1, 2.2, 2.3, 2.4],
        [3.0, 3.1, 3.2, 3.3, 3.4],
        [4.0, 4.1, 4.2, 4.3, 4.4],
        [5.0, 5.1, 5.2, 5.3, 5.4],
        [6.0, 6.1, 6.2, 6.3, 6.4],
        [7.0, 7.1, 7.2, 7.3, 7.4],
        [8.0, 8.1, 8.2, 8.3, 8.4],
        [9.0, 9.1, 9.2, 9.3, 9.4],
    ])

  def _div_sharded_weights(self):
    return [
        constant_op.constant([
            [0.0, 0.1, 0.2, 0.3, 0.4],
            [1.0, 1.1, 1.2, 1.3, 1.4],
        ]),
        constant_op.constant([
            [2.0, 2.1, 2.2, 2.3, 2.4],
            [3.0, 3.1, 3.2, 3.3, 3.4],
        ]),
        constant_op.constant([
            [4.0, 4.1, 4.2, 4.3, 4.4],
            [5.0, 5.1, 5.2, 5.3, 5.4],
        ]),
        constant_op.constant([
            [6.0, 6.1, 6.2, 6.3, 6.4],
            [7.0, 7.1, 7.2, 7.3, 7.4],
        ]),
        constant_op.constant([
            [8.0, 8.1, 8.2, 8.3, 8.4],
            [9.0, 9.1, 9.2, 9.3, 9.4],
        ]),
    ]

  def _mod_sharded_weights(self):
    return [
        constant_op.constant([
            [0.0, 0.1, 0.2, 0.3, 0.4],
            [5.0, 5.1, 5.2, 5.3, 5.4],
        ]),
        constant_op.constant([
            [1.0, 1.1, 1.2, 1.3, 1.4],
            [6.0, 6.1, 6.2, 6.3, 6.4],
        ]),
        constant_op.constant([
            [2.0, 2.1, 2.2, 2.3, 2.4],
            [7.0, 7.1, 7.2, 7.3, 7.4],
        ]),
        constant_op.constant([
            [3.0, 3.1, 3.2, 3.3, 3.4],
            [8.0, 8.1, 8.2, 8.3, 8.4],
        ]),
        constant_op.constant([
            [4.0, 4.1, 4.2, 4.3, 4.4],
            [9.0, 9.1, 9.2, 9.3, 9.4],
        ]),
    ]

  def _biases(self):
    return constant_op.constant(
        [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])

  def _div_sharded_biases(self):
    return [
        constant_op.constant([0.0, 0.1]),
        constant_op.constant([0.2, 0.3]),
        constant_op.constant([0.4, 0.5]),
        constant_op.constant([0.6, 0.7]),
        constant_op.constant([0.8, 0.9]),
    ]

  def _mod_sharded_biases(self):
    return [
        constant_op.constant([0.0, 0.5]),
        constant_op.constant([0.1, 0.6]),
        constant_op.constant([0.2, 0.7]),
        constant_op.constant([0.3, 0.8]),
        constant_op.constant([0.4, 0.9]),
    ]

  def _labels(self):
    return constant_op.constant(
        [[0, 1], [1, 2]],
        shape=(self._batch_size, self._num_true),
        name='labels',
        dtype=dtypes.int64)

  def _inputs(self):
    return constant_op.constant(
        [
            [0., 1., 2., 3., 4.],
            [10., 11., 12., 13., 14.],
        ],
        shape=(self._batch_size, self._embed_dim),
        name='inputs')

  def testInvalidNumSampled0(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(
          ValueError,
          r'num_resampled \(3\) must be less than num_sampled \(3\)'):
        sampling_ops.rank_sampled_softmax_loss(
            weights=self._weights(),
            biases=self._biases(),
            labels=self._labels(),
            inputs=self._inputs(),
            num_sampled=3,
            num_resampled=3,
            num_classes=self._num_classes,
            num_true=self._num_true,
            sampled_values=None,
            resampling_temperature=1.,
            remove_accidental_hits=True,
            partition_strategy='div')

  def testInvalidNumSampled1(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(
          ValueError,
          r'num_resampled \(3\) must be less than num_sampled \(2\)'):
        sampling_ops.rank_sampled_softmax_loss(
            weights=self._weights(),
            biases=self._biases(),
            labels=self._labels(),
            inputs=self._inputs(),
            num_sampled=2,
            num_resampled=3,
            num_classes=self._num_classes,
            num_true=self._num_true,
            sampled_values=None,
            resampling_temperature=1.,
            remove_accidental_hits=True,
            partition_strategy='div')

  def testMissingPartitionStrategy(self):
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError,
                                   r'unsupported partition_strategy \(None\)'):
        sampling_ops.rank_sampled_softmax_loss(
            weights=self._weights(),
            biases=self._biases(),
            labels=self._labels(),
            inputs=self._inputs(),
            num_sampled=2,
            num_resampled=1,
            num_classes=self._num_classes,
            num_true=self._num_true,
            sampled_values=None,
            resampling_temperature=1.,
            remove_accidental_hits=True,
            partition_strategy=None)

  def _testCompareWithNN(self, weights, biases, partition_strategy):
    # With the fixed sampled_values above, adaptive resampling
    # deterministically keeps [5, 6, 7], so rank_sampled_softmax_loss should
    # match sampled_softmax_loss computed directly over the resampled classes.
    with ops.Graph().as_default():
      loss = sampling_ops.rank_sampled_softmax_loss(
          weights=weights(),
          biases=biases(),
          labels=self._labels(),
          inputs=self._inputs(),
          num_sampled=self._num_sampled,
          num_resampled=self._num_resampled,
          num_classes=self._num_classes,
          num_true=self._num_true,
          sampled_values=self._sampled_values,
          resampling_temperature=1.,
          remove_accidental_hits=self._remove_accidental_hits,
          partition_strategy=partition_strategy)
      loss_nn = nn.sampled_softmax_loss(
          weights=weights(),
          biases=biases(),
          labels=self._labels(),
          inputs=self._inputs(),
          num_sampled=self._num_resampled,
          num_classes=self._num_classes,
          num_true=self._num_true,
          sampled_values=self._resampled_values,
          remove_accidental_hits=self._remove_accidental_hits,
          partition_strategy=partition_strategy)
      with self.cached_session() as sess:
        loss_val = sess.run(loss)
        loss_nn_val = sess.run(loss_nn)

    self.assertAllClose(loss_val, loss_nn_val)

  def testCompareWithNNUnsharded(self):
    self._testCompareWithNN(self._weights, self._biases, 'div')

  def testCompareWithNNShardWeightsDiv(self):
    self._testCompareWithNN(self._div_sharded_weights, self._biases, 'div')

  def testCompareWithNNShardWeightsAndBiasesDiv(self):
    self._testCompareWithNN(self._div_sharded_weights, self._div_sharded_biases,
                            'div')

  def testCompareWithNNShardWeightsMod(self):
    self._testCompareWithNN(self._mod_sharded_weights, self._biases, 'mod')

  def testCompareWithNNShardWeightsAndBiasesMod(self):
    self._testCompareWithNN(self._mod_sharded_weights, self._mod_sharded_biases,
                            'mod')

  def _testCompareWithNNTemperature(self, temperature, resampled):
    weights = [[1., 2.], [3., 4.]]  # two sampled classes
    inputs = [[6., -5. / 2.], [-11., 21. / 2.]]
    # Let w0, w1 = weights of sampled classes (biases set to 0 for simplicity)
    # Let x0, x1 = inputs
    # logits:
    #   w0.x0 = 1
    #   w0.x1 = 10
    #   w1.x0 = 8
    #   w1.x1 = 9
    # Resampling 1 class with temperature = t will pick the larger of:
    #   exp(1/t) + exp(10/t)  ==> w0, for values of t < 2.12
    #   exp(8/t) + exp(9/t)   ==> w1, for values of t > 2.13
    num_sampled = 2
    num_resampled = 1
    num_classes = 2
    num_true = 1
    sampled_values = [0, 1], [[1.], [1.]], [1., 1.]
    resampled_values = [resampled], [[1.], [1.]], [1.]
    remove_accidental_hits = False
    with ops.Graph().as_default():
      weights = constant_op.constant(weights)
      biases = constant_op.constant([0., 0.])
      labels = constant_op.constant([[0], [1]], dtype=dtypes.int64)
      inputs = constant_op.constant(inputs)
      loss = sampling_ops.rank_sampled_softmax_loss(
          weights=weights,
          biases=biases,
          labels=labels,
          inputs=inputs,
          num_sampled=num_sampled,
          num_resampled=num_resampled,
          num_classes=num_classes,
          num_true=num_true,
          sampled_values=sampled_values,
          resampling_temperature=constant_op.constant(temperature),
          remove_accidental_hits=remove_accidental_hits,
          partition_strategy='div')
      loss_nn = nn.sampled_softmax_loss(
          weights=weights,
          biases=biases,
          labels=labels,
          inputs=inputs,
          num_sampled=num_resampled,
          num_classes=num_classes,
          num_true=num_true,
          sampled_values=resampled_values,
          remove_accidental_hits=remove_accidental_hits,
          partition_strategy='div')
      with self.cached_session() as sess:
        loss_val = sess.run(loss)
        loss_nn_val = sess.run(loss_nn)

    self.assertAllClose(loss_val, loss_nn_val)
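  # Illustrative sketch (a hypothetical helper, not invoked by any test): a
  # pure-Python check of the crossover worked out in the comment above. A
  # class's resampling score is the sum over examples of
  # exp(logit / temperature), so w0 wins while
  # exp(1/t) + exp(10/t) > exp(8/t) + exp(9/t), i.e. for t below roughly
  # 2.125, and w1 wins above it.
  def _sketchTemperatureCrossover(self):
    import math  # local import: this helper is illustration only

    def score(logits, t):
      return sum(math.exp(logit / t) for logit in logits)

    for t, expected in [(1., 0), (2.12, 0), (2.13, 1), (3., 1)]:
      w0_score = score([1., 10.], t)  # logits of w0 on the two examples
      w1_score = score([8., 9.], t)  # logits of w1 on the two examples
      assert (0 if w0_score > w1_score else 1) == expected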
  def testCompareWithNNTemperatureLo1(self):
    self._testCompareWithNNTemperature(1., 0)

  def testCompareWithNNTemperatureLo2(self):
    self._testCompareWithNNTemperature(2.12, 0)

  def testCompareWithNNTemperatureHi1(self):
    self._testCompareWithNNTemperature(2.13, 1)

  def testCompareWithNNTemperatureHi2(self):
    self._testCompareWithNNTemperature(3., 1)


if __name__ == '__main__':
  test.main()