1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Functional tests for ops used with embeddings.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import itertools 22import math 23 24import numpy as np 25from six.moves import xrange # pylint: disable=redefined-builtin 26 27from tensorflow.python.framework import constant_op 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import sparse_tensor 31from tensorflow.python.framework import test_util 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import data_flow_ops 34from tensorflow.python.ops import embedding_ops 35from tensorflow.python.ops import gradient_checker 36from tensorflow.python.ops import init_ops 37from tensorflow.python.ops import linalg_ops 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import partitioned_variables 40from tensorflow.python.ops import state_ops 41from tensorflow.python.ops import variable_scope 42from tensorflow.python.ops import variables 43from tensorflow.python.platform import test 44from tensorflow.python.platform import tf_logging 45from tensorflow.python.util import compat 46 47 48def _AsLong(array): 49 """Casts arrays elements to long type. Used to convert from numpy tf.""" 50 return [int(x) for x in array] 51 52 53class ScatterAddSubTest(test.TestCase): 54 55 def _TestCase(self, shape, indices, scatter_op=state_ops.scatter_add): 56 """Run a random test case with the given shape and indices. 57 58 Args: 59 shape: Shape of the parameters array. 60 indices: One-dimensional array of ints, the indices of the last dimension 61 of the parameters to update. 62 scatter_op: ScatterAdd or ScatterSub. 63 """ 64 super(ScatterAddSubTest, self).setUp() 65 with self.cached_session(use_gpu=False): 66 # Create a random parameter array of given shape 67 p_init = np.random.rand(*shape).astype("f") 68 # Create the shape of the update array. All dimensions except the last 69 # match the parameter array, the last dimension equals the # of indices. 70 vals_shape = [len(indices)] + shape[1:] 71 vals_init = np.random.rand(*vals_shape).astype("f") 72 v_i = [float(x) for x in vals_init.ravel()] 73 p = variables.Variable(p_init) 74 vals = constant_op.constant(v_i, shape=vals_shape, name="vals") 75 ind = constant_op.constant(indices, dtype=dtypes.int32) 76 p2 = scatter_op(p, ind, vals, name="updated_p") 77 # p = init 78 variables.global_variables_initializer().run() 79 # p += vals 80 result = self.evaluate(p2) 81 # Compute the expected 'p' using numpy operations. 82 for i, ind in enumerate(indices): 83 if scatter_op == state_ops.scatter_add: 84 p_init.reshape(shape[0], -1)[ind, :] += (vals_init.reshape( 85 vals_shape[0], -1)[i, :]) 86 else: 87 p_init.reshape(shape[0], -1)[ind, :] -= (vals_init.reshape( 88 vals_shape[0], -1)[i, :]) 89 self.assertTrue(all((p_init == result).ravel())) 90 91 @test_util.run_deprecated_v1 92 def testNoRepetitions(self): 93 self._TestCase([2, 2], [1]) 94 self._TestCase([4, 4, 4], [2, 0]) 95 self._TestCase([43, 20, 10, 10], [42, 5, 6, 1, 3, 5, 7, 9]) 96 97 @test_util.run_deprecated_v1 98 def testWithRepetitions(self): 99 self._TestCase([2, 2], [1, 1]) 100 self._TestCase([5, 3, 9, 5], [2, 0, 4, 1, 3, 1, 4, 0, 4, 3]) 101 self._TestCase([32, 4, 4], [31] * 8) 102 103 @test_util.run_deprecated_v1 104 def testRandom(self): 105 # Random shapes of rank 4, random indices 106 for _ in range(5): 107 shape = np.random.randint(1, 20, size=4) 108 indices = np.random.randint(shape[0], size=2 * shape[0]) 109 self._TestCase(_AsLong(list(shape)), list(indices)) 110 111 @test_util.run_deprecated_v1 112 def testSubRandom(self): 113 # Random shapes of rank 4, random indices 114 for _ in range(5): 115 shape = np.random.randint(1, 20, size=4) 116 indices = np.random.randint(shape[0], size=2 * shape[0]) 117 self._TestCase(_AsLong(list(shape)), list(indices), state_ops.scatter_sub) 118 119 @test_util.run_deprecated_v1 120 def testWrongShape(self): 121 # Indices and values mismatch. 122 var = variables.Variable( 123 array_ops.zeros(shape=[1024, 64, 64], dtype=dtypes.float32)) 124 indices = array_ops.placeholder(dtypes.int32, shape=[32]) 125 values = array_ops.placeholder(dtypes.float32, shape=[33, 64, 64]) 126 with self.assertRaises(ValueError): 127 state_ops.scatter_add(var, indices, values) 128 129 # Var and values mismatch. 130 values = array_ops.placeholder(dtypes.float32, shape=[32, 64, 63]) 131 with self.assertRaises(ValueError): 132 state_ops.scatter_add(var, indices, values) 133 134 135def _PName(param_id): 136 return "p" + str(param_id) 137 138 139def _EmbeddingParams(num_shards, 140 vocab_size, 141 dtype=dtypes.float32, 142 shape=None, 143 use_shapeless_placeholder=False): 144 p = [] 145 params = {} 146 feed_dict = {} 147 if not shape: 148 shape = [10] 149 for i in range(num_shards): 150 shard_shape = [vocab_size // num_shards] + shape 151 if i < vocab_size % num_shards: # Excess goes evenly on the first shards 152 shard_shape[0] += 1 153 154 param_name = _PName(i) 155 156 if use_shapeless_placeholder: 157 param = array_ops.placeholder(dtype, shape=None, name=param_name) 158 else: 159 param = constant_op.constant( 160 1.0, shape=shard_shape, dtype=dtype, name=param_name) 161 p.append(param) 162 np_type = "f" if dtype == dtypes.float32 else "d" 163 val = (np.random.rand(*shard_shape).astype(np_type)) + 1 164 params[param_name + ":0"] = val 165 feed_dict[param.name] = val 166 return p, params, feed_dict 167 168 169def _EmbeddingParamsAsPartitionedVariable(num_shards, 170 vocab_size, 171 dtype=dtypes.float32, 172 shape=None, 173 use_resource=False): 174 p, params, feed_dict = _EmbeddingParams( 175 num_shards, vocab_size, dtype=dtype, shape=shape) 176 shape = shape or [10] 177 partitioned_variable = variable_scope.get_variable( 178 "p", 179 shape=[vocab_size] + shape, 180 initializer=array_ops.concat([params[p_i.name] for p_i in p], 0), 181 partitioner=partitioned_variables.min_max_variable_partitioner( 182 max_partitions=num_shards, min_slice_size=1), 183 use_resource=use_resource) 184 return p, partitioned_variable, params, feed_dict 185 186 187def _EmbeddingResult(params, 188 id_vals, 189 num_shards, 190 vocab_size, 191 partition_strategy="mod", 192 weight_vals=None): 193 if weight_vals is None: 194 weight_vals = np.copy(id_vals) 195 weight_vals.fill(1) 196 values = [] 197 weights = [] 198 weights_squared = [] 199 for ids, wts in zip(id_vals, weight_vals): 200 value_aggregation = None 201 weight_aggregation = None 202 squared_weight_aggregation = None 203 if isinstance(ids, compat.integral_types): 204 ids = [ids] 205 wts = [wts] 206 for i, weight_value in zip(ids, wts): 207 if partition_strategy == "mod": 208 val = np.copy(params[_PName(i % num_shards) + ":0"][ 209 i // num_shards, :]) * weight_value 210 elif partition_strategy == "div": 211 ids_per_partition, extras = divmod(vocab_size, num_shards) 212 threshold = extras * (ids_per_partition + 1) 213 if i < threshold: 214 partition = i // (ids_per_partition + 1) 215 offset = i % (ids_per_partition + 1) 216 else: 217 partition = extras + (i - threshold) // ids_per_partition 218 offset = (i - threshold) % ids_per_partition 219 val = np.copy( 220 params[_PName(partition) + ":0"][offset, :]) * weight_value 221 else: 222 assert False 223 if value_aggregation is None: 224 assert weight_aggregation is None 225 assert squared_weight_aggregation is None 226 value_aggregation = val 227 weight_aggregation = weight_value 228 squared_weight_aggregation = weight_value * weight_value 229 else: 230 assert weight_aggregation is not None 231 assert squared_weight_aggregation is not None 232 value_aggregation += val 233 weight_aggregation += weight_value 234 squared_weight_aggregation += weight_value * weight_value 235 values.append(value_aggregation) 236 weights.append(weight_aggregation) 237 weights_squared.append(squared_weight_aggregation) 238 values = np.array(values).astype(np.float32) 239 weights = np.array(weights).astype(np.float32) 240 weights_squared = np.array(weights_squared).astype(np.float32) 241 return values, weights, weights_squared 242 243 244class EmbeddingLookupTest(test.TestCase): 245 246 # This test looks up [0, 0] in a parameter matrix sharded 2 ways. Since 247 # both the ids are in the first shard, one of the resulting lookup 248 # vector is going to be empty. The subsequent DivOp fails because of that. 249 # TODO(keveman): Disabling the test until the underlying problem is fixed. 250 @test_util.run_deprecated_v1 251 def testSimpleSharded(self): 252 with self.cached_session(): 253 num_shards = 2 254 vocab_size = 4 255 p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size) 256 257 id_vals = np.array([0, 0]) 258 ids = constant_op.constant(list(id_vals), dtype=dtypes.int32) 259 print("Construct ids", ids.get_shape()) 260 embedding = embedding_ops.embedding_lookup(p, ids) 261 262 tf_result = embedding.eval(feed_dict=feed_dict) 263 np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size) 264 self.assertAllEqual(np_result, tf_result) 265 self.assertShapeEqual(np_result, embedding) 266 267 @test_util.run_deprecated_v1 268 def testMaxNorm(self): 269 with self.cached_session(): 270 embeddings = constant_op.constant([[2.0]]) 271 272 ids = constant_op.constant([0], dtype=dtypes.int32) 273 embedding = embedding_ops.embedding_lookup( 274 [embeddings], ids, max_norm=1.0) 275 276 self.assertAllEqual(embedding.eval(), [[1.0]]) 277 278 @test_util.run_deprecated_v1 279 def testMaxNormNontrivial(self): 280 with self.cached_session(): 281 embeddings = constant_op.constant([[2.0, 4.0], [3.0, 1.0]]) 282 283 ids = constant_op.constant([0, 1], dtype=dtypes.int32) 284 embedding = embedding_ops.embedding_lookup( 285 [embeddings], ids, max_norm=2.0) 286 287 norms = math_ops.sqrt( 288 math_ops.reduce_sum(embeddings * embeddings, axis=1)) 289 normalized = embeddings / array_ops.stack([norms, norms], axis=1) 290 self.assertAllEqual(embedding.eval(), 2 * self.evaluate(normalized)) 291 292 @test_util.run_deprecated_v1 293 def testSimpleShardedPartitionedVariable(self): 294 with self.cached_session() as sess: 295 num_shards = 2 296 vocab_size = 4 297 p, p_variable, params, feed_dict = _EmbeddingParamsAsPartitionedVariable( 298 num_shards, vocab_size) 299 300 id_vals = np.array([0, 0]) 301 ids = constant_op.constant(list(id_vals), dtype=dtypes.int32) 302 print("Construct ids", ids.get_shape()) 303 embedding = embedding_ops.embedding_lookup(p_variable, ids) 304 variables.global_variables_initializer().run() 305 params_values = [params[p_i.name] for p_i in p] 306 # Test that the PartitionedVariable components equal the list in p 307 p_var_val = self.evaluate(list(p_variable)) 308 # Actual test 309 tf_result = embedding.eval(feed_dict=feed_dict) 310 np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size) 311 self.assertAllEqual(params_values, p_var_val) 312 self.assertAllEqual(np_result, tf_result) 313 self.assertShapeEqual(np_result, embedding) 314 315 @test_util.run_deprecated_v1 316 def testSimpleShardedPartitionedResourceVariable(self): 317 with self.cached_session() as sess: 318 num_shards = 2 319 vocab_size = 4 320 p, p_variable, params, _ = _EmbeddingParamsAsPartitionedVariable( 321 num_shards, vocab_size, use_resource=True) 322 323 id_vals = np.array([0, 0]) 324 ids = constant_op.constant(list(id_vals), dtype=dtypes.int32) 325 print("Construct ids", ids.get_shape()) 326 embedding = embedding_ops.embedding_lookup(p_variable, ids) 327 variables.global_variables_initializer().run() 328 params_values = [params[p_i.name] for p_i in p] 329 # Test that the PartitionedVariable components equal the list in p 330 p_var_val = self.evaluate(list(p_variable)) 331 # Actual test 332 print(ops.get_default_graph().as_graph_def()) 333 tf_result = self.evaluate(embedding) 334 np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size) 335 self.assertAllEqual(params_values, p_var_val) 336 self.assertAllEqual(np_result, tf_result) 337 self.assertShapeEqual(np_result, embedding) 338 339 @test_util.run_deprecated_v1 340 def testShardedModPartitioningInt32Ids(self): 341 with self.cached_session(): 342 num_shards = 5 343 vocab_size = 13 344 # Embedding dimensions is 10. The vocab_size x 10 embedding 345 # parameters are spread in num_shards matrices, so the first 346 # 3 shards are 3 x 10 and the last 2 shards are 2 x 10. 347 p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size) 348 349 num_vals = 30 350 # Fetch num_vals embeddings for random word ids. Since 351 # num_vals > vocab_size, this ought to have repetitions, so 352 # will test that aspect. 353 id_vals = np.random.randint(vocab_size, size=num_vals) 354 ids = constant_op.constant(list(id_vals), dtype=dtypes.int32) 355 356 embedding = embedding_ops.embedding_lookup(p, ids) 357 tf_result = embedding.eval(feed_dict=feed_dict) 358 np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size) 359 self.assertAllEqual(np_result, tf_result) 360 self.assertShapeEqual(np_result, embedding) 361 362 @test_util.run_deprecated_v1 363 def testShardedModPartitioningInt64Ids(self): 364 with self.cached_session(): 365 num_shards = 5 366 vocab_size = 13 367 # Embedding dimensions is 10. The vocab_size x 10 embedding 368 # parameters are spread in num_shards matrices, so the first 369 # 3 shards are 3 x 10 and the last 2 shards are 2 x 10. 370 p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size) 371 372 num_vals = 30 373 # Fetch num_vals embeddings for random word ids. Since 374 # num_vals > vocab_size, this ought to have repetitions, so 375 # will test that aspect. 376 id_vals = np.random.randint(vocab_size, size=num_vals) 377 ids = constant_op.constant(list(id_vals), dtype=dtypes.int64) 378 379 embedding = embedding_ops.embedding_lookup(p, ids) 380 tf_result = embedding.eval(feed_dict=feed_dict) 381 np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size) 382 self.assertAllEqual(np_result, tf_result) 383 self.assertShapeEqual(np_result, embedding) 384 385 @test_util.run_deprecated_v1 386 def testShardedDivPartitioningInt32Ids(self): 387 with self.cached_session(): 388 num_shards = 5 389 vocab_size = 13 390 # Embedding dimensions is 10. The vocab_size x 10 embedding 391 # parameters are spread in num_shards matrices, so the first 392 # 3 shards are 3 x 10 and the last 2 shards are 2 x 10. 393 p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size) 394 395 num_vals = 30 396 # Fetch num_vals embeddings for random word ids. Since 397 # num_vals > vocab_size, this ought to have repetitions, so 398 # will test that aspect. 399 id_vals = np.random.randint(vocab_size, size=num_vals) 400 ids = constant_op.constant(list(id_vals), dtype=dtypes.int32) 401 402 embedding = embedding_ops.embedding_lookup( 403 p, ids, partition_strategy="div") 404 tf_result = embedding.eval(feed_dict=feed_dict) 405 np_result, _, _ = _EmbeddingResult( 406 params, id_vals, num_shards, vocab_size, partition_strategy="div") 407 self.assertAllEqual(np_result, tf_result) 408 self.assertShapeEqual(np_result, embedding) 409 410 @test_util.run_deprecated_v1 411 def testShardedDivPartitioningInt32IdsPartitionedVariable(self): 412 with self.cached_session(): 413 num_shards = 5 414 vocab_size = 13 415 # Embedding dimensions is 10. The vocab_size x 10 embedding 416 # parameters are spread in num_shards matrices, so the first 417 # 3 shards are 3 x 10 and the last 2 shards are 2 x 10. 418 _, p_variable, params, feed_dict = _EmbeddingParamsAsPartitionedVariable( 419 num_shards, vocab_size) 420 421 num_vals = 30 422 # Fetch num_vals embeddings for random word ids. Since 423 # num_vals > vocab_size, this ought to have repetitions, so 424 # will test that aspect. 425 id_vals = np.random.randint(vocab_size, size=num_vals) 426 ids = constant_op.constant(list(id_vals), dtype=dtypes.int32) 427 variables.global_variables_initializer().run() 428 embedding = embedding_ops.embedding_lookup( 429 p_variable, ids, partition_strategy="div") 430 tf_result = embedding.eval(feed_dict=feed_dict) 431 np_result, _, _ = _EmbeddingResult( 432 params, id_vals, num_shards, vocab_size, partition_strategy="div") 433 self.assertAllEqual(np_result, tf_result) 434 self.assertShapeEqual(np_result, embedding) 435 436 @test_util.run_deprecated_v1 437 def testShardedDivPartitioningInt64Ids(self): 438 with self.cached_session(): 439 num_shards = 5 440 vocab_size = 13 441 # Embedding dimensions is 10. The vocab_size x 10 embedding 442 # parameters are spread in num_shards matrices, so the first 443 # 3 shards are 3 x 10 and the last 2 shards are 2 x 10. 444 p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size) 445 446 num_vals = 30 447 # Fetch num_vals embeddings for random word ids. Since 448 # num_vals > vocab_size, this ought to have repetitions, so 449 # will test that aspect. 450 id_vals = np.random.randint(vocab_size, size=num_vals) 451 ids = constant_op.constant(list(id_vals), dtype=dtypes.int64) 452 453 embedding = embedding_ops.embedding_lookup( 454 p, ids, partition_strategy="div") 455 tf_result = embedding.eval(feed_dict=feed_dict) 456 np_result, _, _ = _EmbeddingResult( 457 params, id_vals, num_shards, vocab_size, partition_strategy="div") 458 self.assertAllEqual(np_result, tf_result) 459 self.assertShapeEqual(np_result, embedding) 460 461 @test_util.run_deprecated_v1 462 def testShardedDivPartitioningUnknownParamShape(self): 463 with self.cached_session(): 464 num_shards = 5 465 vocab_size = 13 466 # Embedding dimensions is 10. The vocab_size x 10 embedding 467 # parameters are spread in num_shards matrices, so the first 468 # 3 shards are 3 x 10 and the last 2 shards are 2 x 10. 469 470 # We clear parameter shapes, to test when shape is not statically known. 471 p, params, feed_dict = _EmbeddingParams( 472 num_shards, vocab_size, use_shapeless_placeholder=True) 473 474 num_vals = 30 475 # Fetch num_vals embeddings for random word ids. Since 476 # num_vals > vocab_size, this ought to have repetitions, so 477 # will test that aspect. 478 id_vals = np.random.randint(vocab_size, size=num_vals) 479 ids = constant_op.constant(list(id_vals), dtype=dtypes.int64) 480 481 embedding = embedding_ops.embedding_lookup( 482 p, ids, partition_strategy="div") 483 tf_result = embedding.eval(feed_dict=feed_dict) 484 np_result, _, _ = _EmbeddingResult( 485 params, id_vals, num_shards, vocab_size, partition_strategy="div") 486 self.assertAllEqual(np_result, tf_result) 487 488 @test_util.run_deprecated_v1 489 def testGradientsEmbeddingLookup(self): 490 vocab_size = 9 491 num_ids = 10 492 id_vals = list(np.random.randint(vocab_size, size=num_ids)) 493 tf_logging.vlog(1, id_vals) 494 for ids_shape in [(10,), (2, 5)]: 495 for num_shards in [1, 3]: 496 with self.cached_session(): 497 ids = constant_op.constant( 498 id_vals, shape=ids_shape, dtype=dtypes.int32) 499 x, params, _ = _EmbeddingParams(num_shards, vocab_size, shape=[2]) 500 y = embedding_ops.embedding_lookup(x, ids) 501 y_shape = ids_shape + tuple(params[_PName(0) + ":0"].shape[1:]) 502 x_name = [_PName(i) for i in range(num_shards)] 503 x_init_value = [params[x_n + ":0"] for x_n in x_name] 504 x_shape = [i.shape for i in x_init_value] 505 err = gradient_checker.compute_gradient_error( 506 x, x_shape, y, y_shape, x_init_value=x_init_value) 507 self.assertLess(err, 1e-4) 508 509 @test_util.run_deprecated_v1 510 def testGradientsEmbeddingLookupWithComputedParams(self): 511 vocab_size = 9 512 num_ids = 5 513 id_vals = list(np.random.randint(vocab_size, size=num_ids)) 514 tf_logging.vlog(1, id_vals) 515 for num_shards in [1, 3]: 516 with self.cached_session(): 517 ids = constant_op.constant(id_vals, dtype=dtypes.int32) 518 x, params, _ = _EmbeddingParams(num_shards, vocab_size, shape=[2]) 519 # This will force a conversion from IndexedSlices to Tensor. 520 x_squared = [math_ops.square(elem) for elem in x] 521 y = embedding_ops.embedding_lookup(x_squared, ids) 522 y_shape = [num_ids] + list(params[_PName(0) + ":0"].shape[1:]) 523 x_name = [_PName(i) for i in range(num_shards)] 524 x_init_value = [params[x_n + ":0"] for x_n in x_name] 525 x_shape = [i.shape for i in x_init_value] 526 err = gradient_checker.compute_gradient_error( 527 x, x_shape, y, y_shape, x_init_value=x_init_value) 528 self.assertLess(err, 1e-3) 529 530 def testConstructionNonSharded(self): 531 with ops.Graph().as_default(): 532 p = variables.Variable( 533 array_ops.zeros(shape=[100, 100], dtype=dtypes.float32)) 534 ids = constant_op.constant([0, 1, 1, 7], dtype=dtypes.int32) 535 embedding_ops.embedding_lookup([p], ids) 536 537 def testConstructionSharded(self): 538 with ops.Graph().as_default(): 539 p = [] 540 for _ in range(2): 541 p += [ 542 variables.Variable( 543 array_ops.zeros(shape=[100, 100], dtype=dtypes.float32)) 544 ] 545 ids = constant_op.constant([0, 1, 1, 17], dtype=dtypes.int32) 546 embedding_ops.embedding_lookup(p, ids) 547 548 @test_util.run_deprecated_v1 549 def testHigherRank(self): 550 np.random.seed(8) 551 with self.cached_session(): 552 for params_shape in (12,), (6, 3): 553 params = np.random.randn(*params_shape) 554 for ids_shape in (3, 2), (4, 3): 555 ids = np.random.randint( 556 params.shape[0], size=np.prod(ids_shape)).reshape(ids_shape) 557 # Compare nonsharded to gather 558 simple = embedding_ops.embedding_lookup(params, ids).eval() 559 self.assertAllEqual(simple, array_ops.gather(params, ids).eval()) 560 # Run a few random sharded versions 561 for procs in 1, 2, 3: 562 stride = procs * math_ops.range(params.shape[0] // procs) 563 split_params = [ 564 array_ops.gather(params, stride + p) for p in xrange(procs) 565 ] 566 sharded = embedding_ops.embedding_lookup(split_params, ids).eval() 567 self.assertAllEqual(simple, sharded) 568 569 @test_util.run_deprecated_v1 570 def testHigherRankMaxNorm(self): 571 np.random.seed(8) 572 with self.cached_session(): 573 for params_shape in (12,), (6, 3), (6, 2, 3): 574 # Test embedding rank 0, 1, 2. 575 # Note: the first dimension must be a common multiple of procs below. 576 params = 2 * np.ones(params_shape) 577 params_norm = params / np.sqrt( 578 np.sum( 579 params * params, tuple(range(params.ndim)[1:]), keepdims=True)) 580 for ids_shape in (), (3), (4, 3), (2, 3, 4): 581 ids = np.random.randint( 582 params.shape[0], size=np.prod(ids_shape, 583 dtype=np.int64)).reshape(ids_shape) 584 # Compare nonsharded to gather 585 simple = embedding_ops.embedding_lookup( 586 params, ids, max_norm=1.0).eval() 587 # assertAllClose is used here as different implementations of sqrt may 588 # be used to compute each of the values being compared. For example, 589 # on AVX512 builds the embedding operation makes use of Eigen's fast 590 # vectorized square root algorithm for doubles. These different 591 # implementations of sqrt are not guaranteed to produce exactly the 592 # same results. Therefore, an exact comparison cannot be made. 593 self.assertAllClose(simple, array_ops.gather(params_norm, ids).eval()) 594 # Run a few different sharded versions. 595 for procs in 1, 2, 3: 596 stride = procs * math_ops.range(params.shape[0] // procs) 597 split_params = [ 598 array_ops.gather(params, stride + p) for p in xrange(procs) 599 ] 600 sharded = embedding_ops.embedding_lookup( 601 split_params, ids, max_norm=1.0).eval() 602 self.assertAllEqual(simple, sharded) 603 604 @test_util.run_deprecated_v1 605 def testTransform(self): 606 # This tests all combinations of: 607 # - ids rank 0, 1, >1 608 # - params sharded/unsharded 609 # It always applies max_norm. 610 np.random.seed(8) 611 l2_norm = 2. 612 with self.cached_session(): 613 # Param values are in [l2_norm, l2_norm+1) so it will always clip. 614 params = np.random.rand(6, 3) + l2_norm 615 params_norm = l2_norm * params / np.sqrt( 616 np.sum(params * params, axis=1, keepdims=True)) 617 # Compute the norm of each embedding. This will change the embedding 618 # rank to 0. 619 params_norm = np.linalg.norm(params_norm, axis=1) 620 transform = lambda x: linalg_ops.norm(x, axis=1) 621 for ids_shape in (), (3), (4, 3), (2, 3, 4): 622 # Test ids rank 0, 1, 2, 3. 623 ids = np.random.randint( 624 params.shape[0], size=np.prod(ids_shape, 625 dtype=np.int64)).reshape(ids_shape) 626 # Compare nonsharded to gather. 627 simple = embedding_ops._embedding_lookup_and_transform( 628 params, ids, max_norm=l2_norm, transform_fn=transform).eval() 629 self.assertAllClose(simple, array_ops.gather(params_norm, ids).eval()) 630 # Run a few different sharded versions. 631 for procs in 1, 2, 3: 632 stride = procs * math_ops.range(params.shape[0] // procs) 633 split_params = [ 634 array_ops.gather(params, stride + p) for p in xrange(procs) 635 ] 636 sharded = embedding_ops._embedding_lookup_and_transform( 637 split_params, ids, max_norm=l2_norm, 638 transform_fn=transform).eval() 639 # assertAllClose is used here as different implementations of sqrt may 640 # be used to compute each of the values being compared. For example, 641 # on AVX512 builds the embedding operation makes use of Eigen's fast 642 # vectorized square root algorithm for doubles. These different 643 # implementations of sqrt are not guaranteed to produce exactly the 644 # same results. Therefore, an exact comparison cannot be made. 645 self.assertAllClose(simple, sharded) 646 647 648class EmbeddingLookupSparseTest(test.TestCase): 649 650 def _RandomIdsAndWeights(self, batch_size, vocab_size): 651 max_val_per_entry = 6 652 vals_per_batch_entry = np.random.randint( 653 1, max_val_per_entry, size=batch_size) 654 num_vals = np.sum(vals_per_batch_entry) 655 656 ids = np.random.randint(vocab_size, size=num_vals) 657 weights = 1 + np.random.rand(num_vals) 658 659 indices = [] 660 for batch_entry, num_val in enumerate(vals_per_batch_entry): 661 for val_index in range(num_val): 662 indices.append([batch_entry, val_index]) 663 664 shape = [batch_size, max_val_per_entry] 665 666 sp_ids = sparse_tensor.SparseTensor( 667 constant_op.constant(indices, dtypes.int64), 668 constant_op.constant(ids, dtypes.int32), 669 constant_op.constant(shape, dtypes.int64)) 670 sp_weights = sparse_tensor.SparseTensor( 671 constant_op.constant(indices, dtypes.int64), 672 constant_op.constant(weights, dtypes.float32), 673 constant_op.constant(shape, dtypes.int64)) 674 675 return sp_ids, sp_weights, ids, weights, vals_per_batch_entry 676 677 def _GroupByBatchEntry(self, vals, vals_per_batch_entry): 678 grouped_vals = [] 679 index = 0 680 for num_val in vals_per_batch_entry: 681 grouped_vals.append(list(vals[index:(index + num_val)])) 682 index += num_val 683 return grouped_vals 684 685 @test_util.run_deprecated_v1 686 def testEmbeddingLookupSparse(self): 687 vocab_size = 13 688 batch_size = 10 689 param_shape = [2, 5] 690 expected_lookup_result_shape = [None] + param_shape 691 692 sp_ids, sp_weights, ids, weights, vals_per_batch_entry = ( 693 self._RandomIdsAndWeights(batch_size, vocab_size)) 694 695 grouped_ids = self._GroupByBatchEntry(ids, vals_per_batch_entry) 696 grouped_weights = self._GroupByBatchEntry(weights, vals_per_batch_entry) 697 grouped_ignored_weights = self._GroupByBatchEntry( 698 np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry) 699 700 for num_shards, combiner, dtype, ignore_weights in itertools.product( 701 [1, 5], ["sum", "mean", "sqrtn"], 702 [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64], 703 [True, False]): 704 705 with self.cached_session(): 706 p, params, feed_dict = _EmbeddingParams( 707 num_shards, vocab_size, shape=param_shape, dtype=dtype) 708 embedding_sum = embedding_ops.embedding_lookup_sparse( 709 p, 710 sp_ids, 711 None if ignore_weights else sp_weights, 712 combiner=combiner) 713 714 self.assertEqual(embedding_sum.get_shape().as_list(), 715 expected_lookup_result_shape) 716 if dtype in (dtypes.float16, dtypes.bfloat16): 717 self.assertEqual(embedding_sum.dtype, dtypes.float32) 718 else: 719 self.assertEqual(embedding_sum.dtype, dtype) 720 721 tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict) 722 723 np_embedding_sum, np_weight_sum, np_weight_sq_sum = _EmbeddingResult( 724 params, 725 grouped_ids, 726 num_shards, 727 vocab_size, 728 weight_vals=grouped_ignored_weights 729 if ignore_weights else grouped_weights) 730 if combiner == "mean": 731 np_embedding_sum /= np.reshape(np_weight_sum, (batch_size, 1, 1)) 732 if combiner == "sqrtn": 733 np_embedding_sum /= np.reshape( 734 np.sqrt(np_weight_sq_sum), (batch_size, 1, 1)) 735 736 rtol = 1e-6 737 if dtype == dtypes.bfloat16: 738 rtol = 1e-2 739 elif dtype == dtypes.float16: 740 rtol = 1e-3 741 atol = rtol 742 self.assertAllClose(np_embedding_sum, tf_embedding_sum, rtol, atol) 743 744 @test_util.run_deprecated_v1 745 def testGradientsEmbeddingLookupSparse(self): 746 vocab_size = 12 747 batch_size = 4 748 param_shape = [2, 3] 749 sp_ids, sp_weights, _, _, _ = (self._RandomIdsAndWeights( 750 batch_size, vocab_size)) 751 752 for num_shards, combiner, dtype, ignore_weights in itertools.product( 753 [1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32, 754 dtypes.float64], [True, False]): 755 with self.cached_session(): 756 x, params, _ = _EmbeddingParams( 757 num_shards, vocab_size, shape=param_shape, dtype=dtype) 758 759 y = embedding_ops.embedding_lookup_sparse( 760 x, 761 sp_ids, 762 None if ignore_weights else sp_weights, 763 combiner=combiner) 764 x_name = [_PName(i) for i in range(num_shards)] 765 x_init_value = [params[x_n + ":0"] for x_n in x_name] 766 x_shape = [i.shape for i in x_init_value] 767 y_shape = [batch_size] + list(params[_PName(0) + ":0"].shape[1:]) 768 err = gradient_checker.compute_gradient_error( 769 x, x_shape, y, y_shape, x_init_value=x_init_value) 770 self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3) 771 772 @test_util.run_deprecated_v1 773 def testIncompatibleShapes(self): 774 with self.cached_session(): 775 x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32) 776 sp_ids = sparse_tensor.SparseTensor( 777 constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64), 778 constant_op.constant([0, 1, 2], dtypes.int32), 779 constant_op.constant([2, 2], dtypes.int64)) 780 sp_weights = sparse_tensor.SparseTensor( 781 constant_op.constant([[0, 0], [0, 1]], dtypes.int64), 782 constant_op.constant([12.0, 5.0], dtypes.float32), 783 constant_op.constant([1, 2], dtypes.int64)) 784 785 with self.assertRaises(ValueError): 786 embedding_ops.embedding_lookup_sparse( 787 x, sp_ids, sp_weights, combiner="mean") 788 789 790class SafeEmbeddingLookupSparseTest(test.TestCase): 791 792 def _random_weights(self, vocab_size=4, embed_dim=4, num_shards=1): 793 assert vocab_size > 0 794 assert embed_dim > 0 795 assert num_shards > 0 796 assert num_shards <= vocab_size 797 798 initializer = init_ops.truncated_normal_initializer( 799 mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32) 800 embedding_weights = list(variable_scope.get_variable( 801 name="embedding_weights", 802 shape=[vocab_size, embed_dim], 803 partitioner=partitioned_variables.fixed_size_partitioner(num_shards), 804 initializer=initializer)) 805 for w in embedding_weights: 806 w.initializer.run() 807 embedding_weights = [w.eval() for w in embedding_weights] 808 return embedding_weights 809 810 def _ids_and_weights_2d(self): 811 # Each row demonstrates a test case: 812 # Row 0: multiple valid ids, 1 invalid id, weighted mean 813 # Row 1: all ids are invalid (leaving no valid ids after pruning) 814 # Row 2: no ids to begin with 815 # Row 3: single id 816 # Row 4: all ids have <=0 weight 817 indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [4, 0], [4, 1]] 818 ids = [0, 1, -1, -1, 2, 0, 1] 819 weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5] 820 shape = [5, 4] 821 822 sparse_ids = sparse_tensor.SparseTensor( 823 constant_op.constant(indices, dtypes.int64), 824 constant_op.constant(ids, dtypes.int64), 825 constant_op.constant(shape, dtypes.int64)) 826 827 sparse_weights = sparse_tensor.SparseTensor( 828 constant_op.constant(indices, dtypes.int64), 829 constant_op.constant(weights, dtypes.float32), 830 constant_op.constant(shape, dtypes.int64)) 831 832 return sparse_ids, sparse_weights 833 834 def _ids_and_weights_3d(self): 835 # Each (2-D) index demonstrates a test case: 836 # Index 0, 0: multiple valid ids, 1 invalid id, weighted mean 837 # Index 0, 1: all ids are invalid (leaving no valid ids after pruning) 838 # Index 0, 2: no ids to begin with 839 # Index 1, 0: single id 840 # Index 1, 1: all ids have <=0 weight 841 # Index 1, 2: no ids to begin with 842 indices = [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0], [1, 0, 0], [1, 1, 0], 843 [1, 1, 1]] 844 ids = [0, 1, -1, -1, 2, 0, 1] 845 weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5] 846 shape = [2, 3, 4] 847 848 sparse_ids = sparse_tensor.SparseTensor( 849 constant_op.constant(indices, dtypes.int64), 850 constant_op.constant(ids, dtypes.int64), 851 constant_op.constant(shape, dtypes.int64)) 852 853 sparse_weights = sparse_tensor.SparseTensor( 854 constant_op.constant(indices, dtypes.int64), 855 constant_op.constant(weights, dtypes.float32), 856 constant_op.constant(shape, dtypes.int64)) 857 858 return sparse_ids, sparse_weights 859 860 @test_util.run_deprecated_v1 861 def test_safe_embedding_lookup_sparse_return_zero_vector(self): 862 with self.cached_session(): 863 embedding_weights = self._random_weights() 864 sparse_ids, sparse_weights = self._ids_and_weights_2d() 865 866 embedding_lookup_result = ( 867 embedding_ops.safe_embedding_lookup_sparse_v2( 868 embedding_weights, sparse_ids, sparse_weights).eval()) 869 870 self.assertAllClose( 871 embedding_lookup_result, 872 [(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 873 3.0, [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4]) 874 875 @test_util.run_deprecated_v1 876 def test_safe_embedding_lookup_sparse_return_special_vector(self): 877 with self.cached_session(): 878 embedding_weights = self._random_weights() 879 sparse_ids, sparse_weights = self._ids_and_weights_2d() 880 881 embedding_lookup_result = ( 882 embedding_ops.safe_embedding_lookup_sparse_v2( 883 embedding_weights, sparse_ids, sparse_weights, 884 default_id=3).eval()) 885 886 self.assertAllClose( 887 embedding_lookup_result, 888 [(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 889 3.0, embedding_weights[0][3], embedding_weights[0][3], 890 embedding_weights[0][2], embedding_weights[0][3]]) 891 892 @test_util.run_deprecated_v1 893 def test_safe_embedding_lookup_sparse_no_weights(self): 894 with self.cached_session(): 895 embedding_weights = self._random_weights() 896 sparse_ids, _ = self._ids_and_weights_2d() 897 898 embedding_lookup_result = ( 899 embedding_ops.safe_embedding_lookup_sparse_v2( 900 embedding_weights, sparse_ids, None).eval()) 901 902 self.assertAllClose( 903 embedding_lookup_result, 904 [(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, 905 [0] * 4, embedding_weights[0][2], ( 906 embedding_weights[0][0] + embedding_weights[0][1]) / 2.0]) 907 908 @test_util.run_deprecated_v1 909 def test_safe_embedding_lookup_sparse_partitioned(self): 910 with self.cached_session(): 911 embedding_weights = self._random_weights(num_shards=3) 912 sparse_ids, _ = self._ids_and_weights_2d() 913 914 embedding_lookup_result = ( 915 embedding_ops.safe_embedding_lookup_sparse_v2( 916 embedding_weights, sparse_ids, None).eval()) 917 918 embedding_weights = list(itertools.chain(*embedding_weights)) 919 self.assertAllClose(embedding_lookup_result, 920 [(embedding_weights[0] + embedding_weights[1]) / 2.0, 921 [0] * 4, [0] * 4, embedding_weights[2], 922 (embedding_weights[0] + embedding_weights[1]) / 2.0]) 923 924 @test_util.run_deprecated_v1 925 def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self): 926 with self.cached_session(): 927 embedding_weights = self._random_weights(num_shards=3) 928 sparse_ids, sparse_weights = self._ids_and_weights_2d() 929 930 embedding_weights[1] = embedding_weights[1].astype(np.float64) 931 self.assertRaises(TypeError, embedding_ops.safe_embedding_lookup_sparse, 932 embedding_weights, sparse_ids) 933 embedding_weights = [ 934 constant_op.constant(w, dtype=dtypes.float64) 935 for w in embedding_weights 936 ] 937 self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse, 938 embedding_weights, sparse_ids, sparse_weights) 939 940 @test_util.run_deprecated_v1 941 def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self): 942 with self.cached_session(): 943 embedding_weights = self._random_weights() 944 sparse_ids, sparse_weights = self._ids_and_weights_3d() 945 946 embedding_lookup_result = ( 947 embedding_ops.safe_embedding_lookup_sparse_v2( 948 embedding_weights, sparse_ids, sparse_weights).eval()) 949 950 self.assertAllClose(embedding_lookup_result, [[ 951 (1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 3.0, 952 [0] * 4, [0] * 4 953 ], [embedding_weights[0][2], [0] * 4, [0] * 4]]) 954 955 @test_util.run_deprecated_v1 956 def test_safe_embedding_lookup_sparse_3d_return_special_vector(self): 957 with self.cached_session(): 958 embedding_weights = self._random_weights() 959 sparse_ids, sparse_weights = self._ids_and_weights_3d() 960 961 embedding_lookup_result = ( 962 embedding_ops.safe_embedding_lookup_sparse_v2( 963 embedding_weights, sparse_ids, sparse_weights, 964 default_id=3).eval()) 965 966 self.assertAllClose( 967 embedding_lookup_result, 968 [[(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 969 3.0, embedding_weights[0][3], embedding_weights[0][3]], [ 970 embedding_weights[0][2], embedding_weights[0][3], 971 embedding_weights[0][3] 972 ]]) 973 974 @test_util.run_deprecated_v1 975 def test_safe_embedding_lookup_sparse_3d_no_weights(self): 976 with self.cached_session(): 977 embedding_weights = self._random_weights() 978 sparse_ids, _ = self._ids_and_weights_3d() 979 980 embedding_lookup_result = ( 981 embedding_ops.safe_embedding_lookup_sparse_v2( 982 embedding_weights, sparse_ids, None).eval()) 983 984 self.assertAllClose(embedding_lookup_result, [[( 985 embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, [ 986 0 987 ] * 4], [ 988 embedding_weights[0][2], 989 (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4 990 ]]) 991 992 @test_util.run_deprecated_v1 993 def test_safe_embedding_lookup_sparse_3d_partitioned(self): 994 with self.cached_session(): 995 embedding_weights = self._random_weights(num_shards=3) 996 sparse_ids, _ = self._ids_and_weights_3d() 997 998 embedding_lookup_result = ( 999 embedding_ops.safe_embedding_lookup_sparse_v2( 1000 embedding_weights, sparse_ids, None).eval()) 1001 1002 embedding_weights = list(itertools.chain(*embedding_weights)) 1003 self.assertAllClose(embedding_lookup_result, [[ 1004 (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4, [0] * 4 1005 ], [ 1006 embedding_weights[2], 1007 (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4 1008 ]]) 1009 1010 @test_util.run_deprecated_v1 1011 def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights( 1012 self): 1013 with self.cached_session(): 1014 embedding_weights = self._random_weights(num_shards=3) 1015 sparse_ids, sparse_weights = self._ids_and_weights_3d() 1016 1017 embedding_weights[1] = embedding_weights[1].astype(np.float64) 1018 self.assertRaises(TypeError, embedding_ops.safe_embedding_lookup_sparse, 1019 embedding_weights, sparse_ids) 1020 embedding_weights = [ 1021 constant_op.constant(w, dtype=dtypes.float64) 1022 for w in embedding_weights 1023 ] 1024 self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse, 1025 embedding_weights, sparse_ids, sparse_weights) 1026 1027 1028class DynamicStitchOpTest(test.TestCase): 1029 1030 @test_util.run_deprecated_v1 1031 def testCint32Cpu(self): 1032 with self.session(use_gpu=False): 1033 indices = [ 1034 ops.convert_to_tensor([0, 1, 2]), 1035 ops.convert_to_tensor([2, 3]) 1036 ] 1037 values = [ 1038 ops.convert_to_tensor([12, 23, 34]), 1039 ops.convert_to_tensor([1, 2]) 1040 ] 1041 self.assertAllEqual( 1042 data_flow_ops.dynamic_stitch(indices, values).eval(), [12, 23, 1, 2]) 1043 1044 @test_util.run_deprecated_v1 1045 def testCint32Gpu(self): 1046 with self.session(use_gpu=True): 1047 indices = [ 1048 ops.convert_to_tensor([0, 1, 2]), 1049 ops.convert_to_tensor([2, 3]) 1050 ] 1051 values = [ 1052 ops.convert_to_tensor([12, 23, 34]), 1053 ops.convert_to_tensor([1, 2]) 1054 ] 1055 self.assertAllEqual( 1056 data_flow_ops.dynamic_stitch(indices, values).eval(), [12, 23, 1, 2]) 1057 1058 @test_util.run_deprecated_v1 1059 def testInt32Cpu(self): 1060 with self.session(use_gpu=False): 1061 indices = [ 1062 ops.convert_to_tensor([0, 1, 2]), 1063 ops.convert_to_tensor([2, 3]) 1064 ] 1065 values = [ 1066 ops.convert_to_tensor([12, 23, 34]), 1067 ops.convert_to_tensor([1, 2]) 1068 ] 1069 self.assertAllEqual( 1070 data_flow_ops.dynamic_stitch(indices, values).eval(), [12, 23, 1, 2]) 1071 1072 @test_util.run_deprecated_v1 1073 def testInt32Gpu(self): 1074 with self.session(use_gpu=True): 1075 indices = [ 1076 ops.convert_to_tensor([0, 1, 2]), 1077 ops.convert_to_tensor([2, 3]) 1078 ] 1079 values = [ 1080 ops.convert_to_tensor([12, 23, 34]), 1081 ops.convert_to_tensor([1, 2]) 1082 ] 1083 self.assertAllEqual( 1084 data_flow_ops.dynamic_stitch(indices, values).eval(), [12, 23, 1, 2]) 1085 1086 @test_util.run_deprecated_v1 1087 def testSumGradArgs(self): 1088 with self.session(use_gpu=False): 1089 indices = [ 1090 ops.convert_to_tensor([0, 1, 2, 3]), 1091 ops.convert_to_tensor([2, 3]) 1092 ] 1093 values = [ 1094 ops.convert_to_tensor([2, 3, 5, 7]), 1095 ops.convert_to_tensor([1, 1]) 1096 ] 1097 self.assertAllEqual( 1098 data_flow_ops.dynamic_stitch(indices, values).eval(), [2, 3, 1, 1]) 1099 1100 # We expect that the values are merged in order. 1101 @test_util.run_deprecated_v1 1102 def testStitchOrder(self): 1103 with self.cached_session(): 1104 indices = [] 1105 np_values = [] 1106 values = [] 1107 for _ in range(10): 1108 indices.extend([ops.convert_to_tensor(np.arange(100).astype(np.int32))]) 1109 np_values.extend([np.random.uniform(size=100)]) 1110 values.extend([ops.convert_to_tensor(np_values[-1])]) 1111 stitched = data_flow_ops.dynamic_stitch(indices, values).eval() 1112 self.assertAllEqual(np_values[-1], stitched) 1113 1114 1115class ParallelDynamicStitchOpTest(test.TestCase): 1116 1117 @test_util.run_deprecated_v1 1118 def testCint32Cpu(self): 1119 with self.session(use_gpu=False): 1120 indices = [ 1121 ops.convert_to_tensor([0, 1, 4, 6]), 1122 ops.convert_to_tensor([2, 3, 5]) 1123 ] 1124 values = [ 1125 ops.convert_to_tensor([12, 23, 34, 45]), 1126 ops.convert_to_tensor([1, 2, 3]) 1127 ] 1128 self.assertAllEqual( 1129 data_flow_ops.parallel_dynamic_stitch(indices, values).eval(), 1130 [12, 23, 1, 2, 34, 3, 45]) 1131 1132 @test_util.run_deprecated_v1 1133 def testInt32Cpu(self): 1134 with self.session(use_gpu=False): 1135 indices = [ 1136 ops.convert_to_tensor([0, 1, 5, 6, 7]), 1137 ops.convert_to_tensor([2, 4, 3]) 1138 ] 1139 values = [ 1140 ops.convert_to_tensor([12, 23, 34, 45, 56]), 1141 ops.convert_to_tensor([1, 3, 2]) 1142 ] 1143 self.assertAllEqual( 1144 data_flow_ops.parallel_dynamic_stitch(indices, values).eval(), 1145 [12, 23, 1, 2, 3, 34, 45, 56]) 1146 1147 @test_util.run_deprecated_v1 1148 def testSimple(self): 1149 with self.session(use_gpu=False): 1150 indices = [ops.convert_to_tensor([0, 1]), ops.convert_to_tensor([2, 3])] 1151 values = [ops.convert_to_tensor([2, 3]), ops.convert_to_tensor([1, 1])] 1152 self.assertAllEqual( 1153 data_flow_ops.parallel_dynamic_stitch(indices, values).eval(), 1154 [2, 3, 1, 1]) 1155 1156 1157if __name__ == "__main__": 1158 test.main() 1159