1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for third_party.tensorflow.python.ops.ragged_tensor.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22from absl.testing import parameterized 23import numpy as np 24 25from tensorflow.python.data.ops import dataset_ops 26from tensorflow.python.eager import backprop 27from tensorflow.python.eager import context 28from tensorflow.python.framework import constant_op 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import errors 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import tensor_shape 33from tensorflow.python.framework import tensor_spec 34from tensorflow.python.framework import test_util 35from tensorflow.python.ops import array_grad # pylint: disable=unused-import 36from tensorflow.python.ops import array_ops 37from tensorflow.python.ops import control_flow_ops 38from tensorflow.python.ops import gen_ragged_conversion_ops 39from tensorflow.python.ops import gradients_impl 40from tensorflow.python.ops import map_fn 41from tensorflow.python.ops import math_grad # pylint: disable=unused-import 42from tensorflow.python.ops import math_ops 43from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import 44from tensorflow.python.ops.ragged import ragged_factory_ops 45from tensorflow.python.ops.ragged import ragged_math_ops 46from tensorflow.python.ops.ragged import ragged_tensor 47from tensorflow.python.ops.ragged import ragged_tensor_value 48from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor 49from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec 50from tensorflow.python.ops.ragged.row_partition import RowPartition 51 52from tensorflow.python.platform import googletest 53from tensorflow.python.util import nest 54 55 56def int32array(values): 57 return np.array(values, dtype=np.int32) 58 59 60@test_util.run_all_in_graph_and_eager_modes 61class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase): 62 longMessage = True # Property in unittest.Testcase. pylint: disable=invalid-name 63 64 #============================================================================= 65 # RaggedTensor class docstring examples 66 #============================================================================= 67 68 def testClassDocStringExamples(self): 69 # From section: "Component Tensors" 70 rt = RaggedTensor.from_row_splits( 71 values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) 72 self.assertAllEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 73 del rt 74 75 # From section: "Alternative Row-Partitioning Schemes" 76 values = [3, 1, 4, 1, 5, 9, 2, 6] 77 rt1 = RaggedTensor.from_row_splits(values, row_splits=[0, 4, 4, 7, 8, 8]) 78 rt2 = RaggedTensor.from_row_lengths(values, row_lengths=[4, 0, 3, 1, 0]) 79 rt3 = RaggedTensor.from_value_rowids( 80 values, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5) 81 rt4 = RaggedTensor.from_row_starts(values, row_starts=[0, 4, 4, 7, 8]) 82 rt5 = RaggedTensor.from_row_limits(values, row_limits=[4, 4, 7, 8, 8]) 83 for rt in (rt1, rt2, rt3, rt4, rt5): 84 self.assertAllEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []]) 85 del rt1, rt2, rt3, rt4, rt5 86 87 # From section: "Multiple Ragged Dimensions" 88 inner_rt = RaggedTensor.from_row_splits( 89 values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) 90 outer_rt = RaggedTensor.from_row_splits( 91 values=inner_rt, row_splits=[0, 3, 3, 5]) 92 self.assertEqual(outer_rt.ragged_rank, 2) 93 self.assertAllEqual(outer_rt, 94 [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]) 95 del inner_rt, outer_rt 96 97 # From section: "Multiple Ragged Dimensions" 98 rt = RaggedTensor.from_nested_row_splits( 99 flat_values=[3, 1, 4, 1, 5, 9, 2, 6], 100 nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8])) 101 self.assertAllEqual(rt, [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]) 102 del rt 103 104 # From section: "Uniform Inner Dimensions" 105 rt = RaggedTensor.from_row_splits( 106 values=array_ops.ones([5, 3]), row_splits=[0, 2, 5]) 107 self.assertAllEqual( 108 rt, [[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]) 109 self.assertEqual(rt.shape.as_list(), [2, None, 3]) 110 del rt 111 112 #============================================================================= 113 # RaggedTensorValue Constructor 114 #============================================================================= 115 116 def testRaggedTensorValueConstruction(self): 117 values = np.array(b'a b c d e f g'.split()) 118 splits = np.array([0, 2, 5, 6, 6, 7], dtype=np.int64) 119 splits2 = np.array([0, 3, 5], dtype=np.int64) 120 121 # Test construction of a RaggedTensorValue with ragged_rank=1. 122 rt_value = ragged_tensor_value.RaggedTensorValue(values, splits) 123 self.assertEqual(rt_value.row_splits.dtype, np.int64) 124 self.assertEqual(rt_value.shape, (5, None)) 125 self.assertLen(rt_value.nested_row_splits, 1) 126 self.assertAllEqual(splits, rt_value.row_splits) 127 self.assertAllEqual(values, rt_value.values) 128 self.assertAllEqual(splits, rt_value.nested_row_splits[0]) 129 self.assertAllEqual(values, rt_value.flat_values) 130 131 # Test construction of a RaggedTensorValue with ragged_rank=2. 132 rt_value = ragged_tensor_value.RaggedTensorValue( 133 values=ragged_tensor_value.RaggedTensorValue(values, splits), 134 row_splits=splits2) 135 self.assertEqual(rt_value.row_splits.dtype, np.int64) 136 self.assertEqual(rt_value.shape, (2, None, None)) 137 self.assertLen(rt_value.nested_row_splits, 2) 138 self.assertAllEqual(splits2, rt_value.row_splits) 139 self.assertAllEqual(splits, rt_value.values.row_splits) 140 self.assertAllEqual(splits2, rt_value.nested_row_splits[0]) 141 self.assertAllEqual(splits, rt_value.nested_row_splits[1]) 142 self.assertAllEqual(values, rt_value.values.values) 143 self.assertAllEqual(values, rt_value.flat_values) 144 145 #============================================================================= 146 # RaggedTensor Constructor (private) 147 #============================================================================= 148 149 def testRaggedTensorConstruction(self): 150 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 151 row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) 152 rp = RowPartition.from_row_splits(row_splits) 153 rt = RaggedTensor(values=values, row_partition=rp, internal=True) 154 155 self.assertAllEqual(rt, 156 [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) 157 158 def testRaggedTensorConstructionErrors(self): 159 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 160 row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) 161 rp = RowPartition.from_row_splits(row_splits) 162 163 with self.assertRaisesRegex(ValueError, 164 'RaggedTensor constructor is private'): 165 RaggedTensor(values=values, row_partition=rp) 166 167 with self.assertRaisesRegex( 168 TypeError, 169 r"""type\(values\) must be one of: 'Tensor, RaggedTensor.*"""): 170 RaggedTensor(values=range(7), row_partition=rp, internal=True) 171 172 with self.assertRaisesRegex(TypeError, 173 'row_partition must be a RowPartition'): 174 RaggedTensor( 175 values=values, row_partition=[0, 2, 2, 5, 6, 7], internal=True) 176 177 #============================================================================= 178 # RaggedTensor Factory Ops 179 #============================================================================= 180 181 def testFromValueRowIdsWithDerivedNRows(self): 182 # nrows is known at graph creation time. 183 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 184 value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) 185 186 rt = RaggedTensor.from_value_rowids(values, value_rowids, validate=False) 187 self.assertEqual(rt.dtype, dtypes.string) 188 self.assertEqual(rt.shape.as_list(), [5, None]) 189 self.assertEqual(rt.ragged_rank, 1) 190 191 rt_values = rt.values 192 rt_value_rowids = rt.value_rowids() 193 rt_nrows = rt.nrows() 194 195 self.assertIs(rt_values, values) 196 self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids 197 self.assertAllEqual(rt_value_rowids, value_rowids) 198 self.assertAllEqual(rt_nrows, 5) 199 self.assertAllEqual(rt, 200 [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) 201 202 def testFromValueRowIdsWithDerivedNRowsDynamic(self): 203 # nrows is not known at graph creation time. 204 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 205 value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) 206 value_rowids = array_ops.placeholder_with_default(value_rowids, shape=None) 207 208 rt = RaggedTensor.from_value_rowids(values, value_rowids, validate=False) 209 self.assertEqual(rt.dtype, dtypes.string) 210 if context.executing_eagerly(): 211 self.assertEqual(rt.shape.as_list(), [5, None]) 212 else: 213 self.assertEqual(rt.shape.as_list(), [None, None]) 214 self.assertEqual(rt.ragged_rank, 1) 215 216 rt_values = rt.values 217 rt_value_rowids = rt.value_rowids() 218 rt_nrows = rt.nrows() 219 220 self.assertIs(rt_values, values) 221 self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids 222 self.assertAllEqual(rt_value_rowids, value_rowids) 223 self.assertAllEqual(rt_nrows, 5) 224 self.assertAllEqual(rt, 225 [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) 226 227 def testFromValueRowIdsWithExplicitNRows(self): 228 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 229 value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) 230 nrows = constant_op.constant(7, dtypes.int64) 231 232 rt = RaggedTensor.from_value_rowids( 233 values, value_rowids, nrows, validate=False) 234 self.assertEqual(rt.dtype, dtypes.string) 235 self.assertEqual(rt.shape.as_list(), [7, None]) 236 self.assertEqual(rt.ragged_rank, 1) 237 238 rt_values = rt.values 239 rt_value_rowids = rt.value_rowids() 240 rt_nrows = rt.nrows() 241 242 self.assertIs(rt_values, values) 243 self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids 244 self.assertIs(rt_nrows, nrows) # cached_nrows 245 self.assertAllEqual( 246 rt, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g'], [], []]) 247 248 def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self): 249 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 250 value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) 251 nrows = constant_op.constant(5, dtypes.int64) 252 253 rt = RaggedTensor.from_value_rowids( 254 values, value_rowids, nrows, validate=False) 255 self.assertEqual(rt.dtype, dtypes.string) 256 self.assertEqual(rt.shape.as_list(), [5, None]) 257 self.assertEqual(rt.ragged_rank, 1) 258 259 rt_values = rt.values 260 rt_value_rowids = rt.value_rowids() 261 rt_nrows = rt.nrows() 262 263 self.assertIs(rt_values, values) 264 self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids 265 self.assertIs(rt_nrows, nrows) # cached_nrows 266 self.assertAllEqual(rt_value_rowids, value_rowids) 267 self.assertAllEqual(rt_nrows, nrows) 268 self.assertAllEqual(rt, 269 [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) 270 271 def testFromValueRowIdsWithEmptyValues(self): 272 rt = RaggedTensor.from_value_rowids([], []) 273 rt_nrows = rt.nrows() 274 self.assertEqual(rt.dtype, dtypes.float32) 275 self.assertEqual(rt.shape.as_list(), [0, None]) 276 self.assertEqual(rt.ragged_rank, 1) 277 self.assertEqual(rt.values.shape.as_list(), [0]) 278 self.assertEqual(rt.value_rowids().shape.as_list(), [0]) 279 self.assertAllEqual(rt_nrows, 0) 280 self.assertAllEqual(rt, []) 281 282 def testFromRowSplits(self): 283 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 284 row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) 285 286 rt = RaggedTensor.from_row_splits(values, row_splits, validate=False) 287 self.assertEqual(rt.dtype, dtypes.string) 288 self.assertEqual(rt.shape.as_list(), [5, None]) 289 self.assertEqual(rt.ragged_rank, 1) 290 291 rt_values = rt.values 292 rt_row_splits = rt.row_splits 293 rt_nrows = rt.nrows() 294 295 self.assertIs(rt_values, values) 296 self.assertIs(rt_row_splits, row_splits) 297 self.assertAllEqual(rt_nrows, 5) 298 self.assertAllEqual(rt, 299 [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) 300 301 def testFromRowSplitsWithDifferentSplitTypes(self): 302 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 303 splits1 = [0, 2, 2, 5, 6, 7] 304 splits2 = np.array([0, 2, 2, 5, 6, 7], np.int64) 305 splits3 = np.array([0, 2, 2, 5, 6, 7], np.int32) 306 splits4 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) 307 splits5 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int32) 308 rt1 = RaggedTensor.from_row_splits(values, splits1) 309 rt2 = RaggedTensor.from_row_splits(values, splits2) 310 rt3 = RaggedTensor.from_row_splits(values, splits3) 311 rt4 = RaggedTensor.from_row_splits(values, splits4) 312 rt5 = RaggedTensor.from_row_splits(values, splits5) 313 self.assertEqual(rt1.row_splits.dtype, dtypes.int64) 314 self.assertEqual(rt2.row_splits.dtype, dtypes.int64) 315 self.assertEqual(rt3.row_splits.dtype, dtypes.int32) 316 self.assertEqual(rt4.row_splits.dtype, dtypes.int64) 317 self.assertEqual(rt5.row_splits.dtype, dtypes.int32) 318 319 def testFromRowSplitsWithEmptySplits(self): 320 err_msg = 'row_splits tensor may not be empty' 321 with self.assertRaisesRegex(ValueError, err_msg): 322 RaggedTensor.from_row_splits([], []) 323 324 def testFromRowStarts(self): 325 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 326 row_starts = constant_op.constant([0, 2, 2, 5, 6], dtypes.int64) 327 328 rt = RaggedTensor.from_row_starts(values, row_starts, validate=False) 329 self.assertEqual(rt.dtype, dtypes.string) 330 self.assertEqual(rt.shape.as_list(), [5, None]) 331 self.assertEqual(rt.ragged_rank, 1) 332 333 rt_values = rt.values 334 rt_row_starts = rt.row_starts() 335 rt_nrows = rt.nrows() 336 337 self.assertIs(rt_values, values) 338 self.assertAllEqual(rt_nrows, 5) 339 self.assertAllEqual(rt_row_starts, row_starts) 340 self.assertAllEqual(rt, 341 [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) 342 343 def testFromRowLimits(self): 344 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 345 row_limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64) 346 347 rt = RaggedTensor.from_row_limits(values, row_limits, validate=False) 348 self.assertEqual(rt.dtype, dtypes.string) 349 self.assertEqual(rt.shape.as_list(), [5, None]) 350 self.assertEqual(rt.ragged_rank, 1) 351 352 rt_values = rt.values 353 rt_row_limits = rt.row_limits() 354 rt_nrows = rt.nrows() 355 356 self.assertIs(rt_values, values) 357 self.assertAllEqual(rt_nrows, 5) 358 self.assertAllEqual(rt_row_limits, row_limits) 359 self.assertAllEqual(rt, 360 [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) 361 362 def testFromRowLengths(self): 363 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 364 row_lengths = constant_op.constant([2, 0, 3, 1, 1], dtypes.int64) 365 366 rt = RaggedTensor.from_row_lengths(values, row_lengths, validate=False) 367 self.assertEqual(rt.dtype, dtypes.string) 368 self.assertEqual(rt.shape.as_list(), [5, None]) 369 self.assertEqual(rt.ragged_rank, 1) 370 371 rt_values = rt.values 372 rt_row_lengths = rt.row_lengths() 373 rt_nrows = rt.nrows() 374 375 self.assertIs(rt_values, values) 376 self.assertIs(rt_row_lengths, row_lengths) # cached_nrows 377 self.assertAllEqual(rt_nrows, 5) 378 self.assertAllEqual(rt_row_lengths, row_lengths) 379 self.assertAllEqual(rt, 380 [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) 381 382 def testFromRowLengthsInt32(self): 383 rt = RaggedTensor.from_row_lengths([1, 2, 3, 4], 384 constant_op.constant([1, 0, 3], 385 dtype=dtypes.int32)) 386 rt2 = RaggedTensor.from_row_lengths(rt, [2, 1, 0]) 387 self.assertAllEqual([2, 1, 0], rt2.row_lengths()) 388 389 def testFromUniformRowLength(self): 390 values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] 391 392 a1 = RaggedTensor.from_uniform_row_length(values, 2) 393 a2 = RaggedTensor.from_uniform_row_length(values, 2, 8) 394 self.assertAllEqual( 395 a1, 396 [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]) 397 self.assertAllEqual(a1, a2) 398 self.assertEqual(a1.shape.as_list(), [8, 2]) 399 self.assertEqual(a2.shape.as_list(), [8, 2]) 400 401 b1 = RaggedTensor.from_uniform_row_length(a1, 2) 402 b2 = RaggedTensor.from_uniform_row_length(a1, 2, 4) 403 self.assertAllEqual(b1, [[[1, 2], [3, 4]], [[5, 6], [7, 8]], 404 [[9, 10], [11, 12]], [[13, 14], [15, 16]]]) 405 self.assertAllEqual(b1, b2) 406 self.assertEqual(b1.shape.as_list(), [4, 2, 2]) 407 self.assertEqual(b2.shape.as_list(), [4, 2, 2]) 408 409 c1 = RaggedTensor.from_uniform_row_length(b1, 2) 410 c2 = RaggedTensor.from_uniform_row_length(b1, 2, 2) 411 self.assertAllEqual(c1, [[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], 412 [[[9, 10], [11, 12]], [[13, 14], [15, 16]]]]) 413 self.assertAllEqual(c1, c2) 414 self.assertEqual(c1.shape.as_list(), [2, 2, 2, 2]) 415 self.assertEqual(c2.shape.as_list(), [2, 2, 2, 2]) 416 417 def testFromUniformRowLengthWithEmptyValues(self): 418 empty_values = [] 419 a = RaggedTensor.from_uniform_row_length(empty_values, 0, nrows=10) 420 self.assertEqual(a.shape.as_list(), [10, 0]) 421 422 b = RaggedTensor.from_uniform_row_length(a, 2) 423 self.assertEqual(b.shape.as_list(), [5, 2, 0]) 424 425 # Make sure we avoid divide-by-zero when finding nrows for nvals=rowlen=0. 426 c = RaggedTensor.from_uniform_row_length(empty_values, 0) 427 self.assertEqual(c.shape.as_list(), [0, 0]) 428 d = RaggedTensor.from_uniform_row_length(empty_values, 0, nrows=0) 429 self.assertEqual(d.shape.as_list(), [0, 0]) 430 431 def testFromUniformRowLengthWithPlaceholders(self): 432 ph_values = array_ops.placeholder_with_default([1, 2, 3, 4, 5, 6], [None]) 433 ph_rowlen = array_ops.placeholder_with_default(3, None) 434 rt1 = RaggedTensor.from_uniform_row_length(ph_values, 3) 435 rt2 = RaggedTensor.from_uniform_row_length(ph_values, ph_rowlen) 436 rt3 = RaggedTensor.from_uniform_row_length([1, 2, 3, 4, 5, 6], ph_rowlen) 437 self.assertAllEqual(rt1, [[1, 2, 3], [4, 5, 6]]) 438 self.assertAllEqual(rt2, [[1, 2, 3], [4, 5, 6]]) 439 self.assertAllEqual(rt3, [[1, 2, 3], [4, 5, 6]]) 440 if context.executing_eagerly(): 441 self.assertEqual(rt1.shape.as_list(), [2, 3]) 442 self.assertEqual(rt2.shape.as_list(), [2, 3]) 443 self.assertEqual(rt3.shape.as_list(), [2, 3]) 444 else: 445 self.assertEqual(rt1.shape.as_list(), [None, 3]) 446 self.assertEqual(rt2.shape.as_list(), [None, None]) 447 self.assertEqual(rt3.shape.as_list(), [None, None]) 448 449 b = RaggedTensor.from_uniform_row_length(rt1, 2) 450 self.assertAllEqual(b, [[[1, 2, 3], [4, 5, 6]]]) 451 452 # Make sure we avoid divide-by-zero when finding nrows for nvals=rowlen=0. 453 ph_empty_values = array_ops.placeholder_with_default( 454 array_ops.zeros([0], dtypes.int64), [None]) 455 ph_zero = array_ops.placeholder_with_default(0, []) 456 c = RaggedTensor.from_uniform_row_length(ph_empty_values, ph_zero) 457 if context.executing_eagerly(): 458 self.assertEqual(c.shape.as_list(), [0, 0]) 459 else: 460 self.assertEqual(c.shape.as_list(), [None, None]) 461 462 def testFromNestedValueRowIdsWithDerivedNRows(self): 463 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 464 nested_value_rowids = [ 465 constant_op.constant([0, 0, 1, 3, 3], dtypes.int64), 466 constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) 467 ] 468 469 rt = RaggedTensor.from_nested_value_rowids(values, nested_value_rowids) 470 self.assertEqual(rt.dtype, dtypes.string) 471 self.assertEqual(rt.shape.as_list(), [4, None, None]) 472 self.assertEqual(rt.ragged_rank, 2) 473 474 rt_values = rt.values 475 rt_value_rowids = rt.value_rowids() 476 rt_values_values = rt_values.values 477 rt_values_value_rowids = rt_values.value_rowids() 478 479 self.assertIs(rt_values_values, values) 480 self.assertAllEqual(rt_value_rowids, nested_value_rowids[0]) 481 self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1]) 482 self.assertAllEqual( 483 rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]]) 484 485 def testFromNestedValueRowIdsWithExplicitNRows(self): 486 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 487 nested_value_rowids = [ 488 constant_op.constant([0, 0, 1, 3, 3, 3], dtypes.int64), 489 constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) 490 ] 491 nrows = [ 492 constant_op.constant(6, dtypes.int64), 493 constant_op.constant(6, dtypes.int64) 494 ] 495 496 rt = RaggedTensor.from_nested_value_rowids(values, nested_value_rowids, 497 nrows) 498 self.assertEqual(rt.dtype, dtypes.string) 499 self.assertEqual(rt.shape.as_list(), [6, None, None]) 500 self.assertEqual(rt.ragged_rank, 2) 501 502 rt_values = rt.values 503 rt_value_rowids = rt.value_rowids() 504 rt_nrows = rt.nrows() 505 rt_values_values = rt_values.values 506 rt_values_value_rowids = rt_values.value_rowids() 507 rt_values_nrows = rt_values.nrows() 508 509 self.assertIs(rt_values_values, values) 510 self.assertAllEqual(rt_value_rowids, nested_value_rowids[0]) 511 self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1]) 512 self.assertAllEqual(rt_nrows, nrows[0]) 513 self.assertAllEqual(rt_values_nrows, nrows[1]) 514 self.assertAllEqual(rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], 515 [[b'f'], [b'g'], []], [], []]) 516 517 def testFromNestedValueRowIdsWithExplicitNRowsMismatch(self): 518 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 519 nested_value_rowids = [ 520 constant_op.constant([0, 0, 1, 3, 3, 3], dtypes.int64), 521 constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) 522 ] 523 nrows = [constant_op.constant(6, dtypes.int64)] 524 with self.assertRaisesRegex( 525 ValueError, 'nested_nrows must have the same ' 526 'length as nested_value_rowids'): 527 RaggedTensor.from_nested_value_rowids(values, nested_value_rowids, nrows) 528 529 def testFromNestedValueRowIdsWithNonListInput(self): 530 with self.assertRaisesRegex( 531 TypeError, 'nested_value_rowids must be a list of Tensors'): 532 RaggedTensor.from_nested_value_rowids( 533 [1, 2, 3], constant_op.constant([[0, 1, 2], [0, 1, 2]], dtypes.int64)) 534 with self.assertRaisesRegex(TypeError, 535 'nested_nrows must be a list of Tensors'): 536 RaggedTensor.from_nested_value_rowids([1, 2, 3], [[0, 1, 2], [0, 1, 2]], 537 constant_op.constant([3, 3])) 538 539 def testFromNestedRowSplits(self): 540 flat_values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 541 nested_row_splits = [ 542 constant_op.constant([0, 2, 3, 3, 5], dtypes.int64), 543 constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) 544 ] 545 546 rt = RaggedTensor.from_nested_row_splits( 547 flat_values, nested_row_splits, validate=False) 548 self.assertEqual(rt.dtype, dtypes.string) 549 self.assertEqual(rt.shape.as_list(), [4, None, None]) 550 self.assertEqual(rt.ragged_rank, 2) 551 552 rt_values = rt.values 553 rt_row_splits = rt.row_splits 554 rt_values_values = rt_values.values 555 rt_values_row_splits = rt_values.row_splits 556 557 self.assertIs(rt_values_values, flat_values) 558 self.assertIs(rt_row_splits, nested_row_splits[0]) 559 self.assertIs(rt_values_row_splits, nested_row_splits[1]) 560 self.assertAllEqual( 561 rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]]) 562 563 def testWithRowSplits(self): 564 flat_values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 565 nested_row_splits = [ 566 constant_op.constant([0, 2, 3, 3, 5], dtypes.int64), 567 constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) 568 ] 569 570 rt = RaggedTensor.from_nested_row_splits( 571 flat_values, nested_row_splits, validate=False) 572 573 rt = rt.with_row_splits_dtype(dtypes.int32) 574 575 self.assertEqual(rt.dtype, dtypes.string) 576 self.assertEqual(rt.shape.as_list(), [4, None, None]) 577 self.assertEqual(rt.ragged_rank, 2) 578 579 rt_values = rt.values 580 rt_row_splits = rt.row_splits 581 rt_values_values = rt_values.values 582 rt_values_row_splits = rt_values.row_splits 583 584 self.assertAllEqual(rt_values_values, flat_values) 585 self.assertAllEqual(rt_row_splits, nested_row_splits[0]) 586 self.assertAllEqual(rt_values_row_splits, nested_row_splits[1]) 587 self.assertAllEqual( 588 rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]]) 589 590 def testFromNestedRowSplitsWithNonListInput(self): 591 with self.assertRaisesRegex(TypeError, 592 'nested_row_splits must be a list of Tensors'): 593 RaggedTensor.from_nested_row_splits( 594 [1, 2], constant_op.constant([[0, 1, 2], [0, 1, 2]], dtypes.int64)) 595 596 def testFromValueRowIdsWithBadNRows(self): 597 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 598 value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) 599 nrows = constant_op.constant(5, dtypes.int64) 600 601 with self.assertRaisesRegex(ValueError, r'Expected nrows >= 0; got -2'): 602 RaggedTensor.from_value_rowids( 603 values=values, 604 value_rowids=array_ops.placeholder_with_default(value_rowids, None), 605 nrows=-2) 606 607 with self.assertRaisesRegex( 608 ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=2, ' 609 r'value_rowids\[-1\]=4'): 610 RaggedTensor.from_value_rowids( 611 values=values, value_rowids=value_rowids, nrows=2) 612 613 with self.assertRaisesRegex( 614 ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=4, ' 615 r'value_rowids\[-1\]=4'): 616 RaggedTensor.from_value_rowids( 617 values=values, value_rowids=value_rowids, nrows=4) 618 619 with self.assertRaisesRegex(ValueError, r'Shape \(7, 1\) must have rank 1'): 620 RaggedTensor.from_value_rowids( 621 values=values, 622 value_rowids=array_ops.expand_dims(value_rowids, 1), 623 nrows=nrows) 624 625 with self.assertRaisesRegex(ValueError, r'Shape \(1,\) must have rank 0'): 626 RaggedTensor.from_value_rowids( 627 values=values, 628 value_rowids=value_rowids, 629 nrows=array_ops.expand_dims(nrows, 0)) 630 631 def testCondWithTensorsFromValueIds(self): 632 # b/141166460 633 rt = RaggedTensor.from_value_rowids([1, 2, 3], [0, 0, 2]) 634 c = array_ops.placeholder_with_default(True, None) 635 result = control_flow_ops.cond(c, lambda: rt, lambda: rt) 636 self.assertAllEqual(rt, result) 637 638 def testGraphMismatch(self): 639 if not context.executing_eagerly(): 640 with ops.Graph().as_default(): 641 values = constant_op.constant([1, 2, 3], dtypes.int64) 642 with ops.Graph().as_default(): 643 splits = constant_op.constant([0, 2, 3], dtypes.int64) 644 with self.assertRaisesRegex(ValueError, 645 '.* must be from the same graph as .*'): 646 RaggedTensor.from_row_splits(values, splits) 647 648 #============================================================================= 649 # Ragged Value & Row-Partitioning Tensor Accessors 650 #============================================================================= 651 652 def testRaggedTensorAccessors_2d(self): 653 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 654 row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) 655 value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) 656 rt1 = RaggedTensor.from_row_splits(values, row_splits) 657 rt2 = RaggedTensor.from_value_rowids(values, value_rowids) 658 659 for rt in [rt1, rt2]: 660 self.assertAllEqual( 661 rt, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) 662 self.assertAllEqual(rt.values, [b'a', b'b', b'c', b'd', b'e', b'f', b'g']) 663 self.assertEqual(rt.values.shape.dims[0].value, 7) 664 self.assertAllEqual(rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4]) 665 self.assertAllEqual(rt.nrows(), 5) 666 self.assertAllEqual(rt.row_splits, [0, 2, 2, 5, 6, 7]) 667 self.assertAllEqual(rt.row_starts(), [0, 2, 2, 5, 6]) 668 self.assertAllEqual(rt.row_limits(), [2, 2, 5, 6, 7]) 669 self.assertAllEqual(rt.row_lengths(), [2, 0, 3, 1, 1]) 670 self.assertAllEqual(rt.flat_values, 671 [b'a', b'b', b'c', b'd', b'e', b'f', b'g']) 672 self.assertLen(rt.nested_row_splits, 1) 673 self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 2, 5, 6, 7]) 674 675 def testRaggedTensorAccessors_3d_with_ragged_rank_1(self): 676 values = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]] 677 row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) 678 value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) 679 row_lengths = constant_op.constant([2, 0, 3, 1, 1]) 680 rt1 = RaggedTensor.from_row_splits(values, row_splits) 681 rt2 = RaggedTensor.from_value_rowids(values, value_rowids) 682 rt3 = RaggedTensor.from_row_lengths(values, row_lengths) 683 684 for rt in [rt1, rt2, rt3]: 685 self.assertAllEqual(rt, [[[0, 1], [2, 3]], [], [[4, 5], [6, 7], [8, 9]], 686 [[10, 11]], [[12, 13]]]) 687 self.assertAllEqual( 688 rt.values, 689 [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]) 690 self.assertEqual(rt.values.shape.dims[0].value, 7) 691 self.assertAllEqual(rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4]) 692 self.assertAllEqual(rt.nrows(), 5) 693 self.assertAllEqual(rt.row_splits, [0, 2, 2, 5, 6, 7]) 694 self.assertAllEqual(rt.row_starts(), [0, 2, 2, 5, 6]) 695 self.assertAllEqual(rt.row_limits(), [2, 2, 5, 6, 7]) 696 self.assertAllEqual(rt.row_lengths(), [2, 0, 3, 1, 1]) 697 self.assertAllEqual(rt.row_lengths(axis=2), 698 [[2, 2], [], [2, 2, 2], [2], [2]]) 699 self.assertAllEqual( 700 rt.flat_values, 701 [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]) 702 self.assertLen(rt.nested_row_splits, 1) 703 self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 2, 5, 6, 7]) 704 self.assertLen(rt.nested_value_rowids(), 1) 705 706 self.assertAllEqual(rt.nested_value_rowids()[0], [0, 0, 2, 2, 2, 3, 4]) 707 708 def testRaggedTensorAccessors_3d_with_ragged_rank_2(self): 709 values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) 710 nested_row_splits = [ 711 constant_op.constant([0, 2, 3, 3, 5], dtypes.int64), 712 constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) 713 ] 714 nested_value_rowids = [ 715 constant_op.constant([0, 0, 1, 3, 3], dtypes.int64), 716 constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) 717 ] 718 rt1 = RaggedTensor.from_nested_row_splits(values, nested_row_splits) 719 rt2 = RaggedTensor.from_nested_value_rowids(values, nested_value_rowids) 720 721 for rt in [rt1, rt2]: 722 self.assertAllEqual( 723 rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]]) 724 self.assertAllEqual( 725 rt.values, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) 726 self.assertEqual(rt.values.shape.dims[0].value, 5) 727 self.assertAllEqual(rt.value_rowids(), [0, 0, 1, 3, 3]) 728 self.assertAllEqual(rt.nrows(), 4) 729 self.assertAllEqual(rt.row_splits, [0, 2, 3, 3, 5]) 730 self.assertAllEqual(rt.row_starts(), [0, 2, 3, 3]) 731 self.assertAllEqual(rt.row_limits(), [2, 3, 3, 5]) 732 self.assertAllEqual(rt.row_lengths(), [2, 1, 0, 2]) 733 self.assertAllEqual(rt.flat_values, 734 [b'a', b'b', b'c', b'd', b'e', b'f', b'g']) 735 self.assertLen(rt.nested_row_splits, 2) 736 self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 3, 3, 5]) 737 self.assertAllEqual(rt.nested_row_splits[1], [0, 2, 2, 5, 6, 7]) 738 self.assertLen(rt.nested_value_rowids(), 2) 739 self.assertAllEqual(rt.nested_value_rowids()[0], [0, 0, 1, 3, 3]) 740 self.assertAllEqual(rt.nested_value_rowids()[1], [0, 0, 2, 2, 2, 3, 4]) 741 742 #============================================================================= 743 # RaggedTensor.shape 744 #============================================================================= 745 746 def testShape(self): 747 """Tests for RaggedTensor.shape.""" 748 rt1 = RaggedTensor.from_row_splits(b'a b c d e f g'.split(), 749 [0, 2, 5, 6, 6, 7]) 750 self.assertEqual(rt1.shape.as_list(), [5, None]) 751 752 rt2 = RaggedTensor.from_row_splits( 753 [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]], 754 [0, 2, 5, 6, 6, 7]) 755 self.assertEqual(rt2.shape.as_list(), [5, None, 2]) 756 757 rt3 = RaggedTensor.from_row_splits( 758 [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]], [0, 2, 2, 3]) 759 self.assertEqual(rt3.shape.as_list(), [3, None, 2, 2]) 760 761 rt4 = RaggedTensor.from_row_splits(rt3, [0, 1, 3, 3]) 762 self.assertEqual(rt4.shape.as_list(), [3, None, None, 2, 2]) 763 764 if not context.executing_eagerly(): 765 rt5 = RaggedTensor.from_row_splits( 766 array_ops.placeholder(dtype=dtypes.string), [0, 2, 3, 5]) 767 self.assertIsNone(rt5.shape.ndims) 768 769 rt6 = RaggedTensor.from_row_splits( 770 [1, 2, 3], array_ops.placeholder(dtype=dtypes.int64)) 771 self.assertEqual(rt6.shape.as_list(), [None, None]) 772 773 def testGetShape(self): 774 rt = RaggedTensor.from_row_splits(b'a b c d e f g'.split(), 775 [0, 2, 5, 6, 6, 7]) 776 self.assertEqual(rt.shape.as_list(), rt.get_shape().as_list()) 777 778 #============================================================================= 779 # RaggedTensor.__str__ 780 #============================================================================= 781 def testRaggedTensorStr(self): 782 values = [b'a', b'b', b'c', b'd', b'e', b'f', b'g'] 783 row_splits = [0, 2, 5, 6, 6, 7] 784 rt = RaggedTensor.from_row_splits(values, row_splits, validate=False) 785 splits_type = 'int64' 786 if context.executing_eagerly(): 787 expected_repr = '<tf.RaggedTensor {}>'.format([[b'a', b'b'], 788 [b'c', b'd', b'e'], [b'f'], 789 [], [b'g']]) 790 else: 791 expected_repr = ( 792 'tf.RaggedTensor(values=Tensor("RaggedFromRowSplits/values:0", ' 793 'shape=(7,), dtype=string), ' 794 'row_splits=Tensor(' 795 '"RaggedFromRowSplits/RowPartitionFromRowSplits/row_splits:0",' 796 ' shape=(6,), dtype={}))').format(splits_type) 797 self.assertEqual(repr(rt), expected_repr) 798 self.assertEqual(str(rt), expected_repr) 799 800 def testRaggedTensorValueStr(self): 801 values = [b'a', b'b', b'c', b'd', b'e', b'f', b'g'] 802 row_splits = [0, 2, 5, 6, 6, 7] 803 rt = ragged_tensor_value.RaggedTensorValue( 804 np.array(values), np.array(row_splits, dtype=np.int64)) 805 expected_str = '<tf.RaggedTensorValue {}>'.format([[b'a', b'b'], 806 [b'c', b'd', b'e'], 807 [b'f'], [], [b'g']]) 808 expected_repr = ("tf.RaggedTensorValue(values=array({}, dtype='|S1'), " 809 'row_splits=array({}))'.format(values, row_splits)) 810 self.assertEqual(' '.join(str(rt).split()), expected_str) 811 self.assertEqual(' '.join(repr(rt).split()), expected_repr) 812 813 #============================================================================= 814 # RaggedTensor.with_values() and RaggedTensor.with_flat_values(). 815 #============================================================================= 816 817 def testWithValues(self): 818 rt1 = ragged_factory_ops.constant([[1, 2], [3, 4, 5], [6], [], [7]]) 819 rt2 = ragged_factory_ops.constant([[[1, 2], [3, 4, 5]], [[6]], [], [[], 820 [7]]]) 821 822 rt1_plus_10 = rt1.with_values(rt1.values + 10) 823 rt2_times_10 = rt2.with_flat_values(rt2.flat_values * 10) 824 rt1_expanded = rt1.with_values(array_ops.expand_dims(rt1.values, axis=1)) 825 826 self.assertAllEqual(rt1_plus_10, [[11, 12], [13, 14, 15], [16], [], [17]]) 827 self.assertAllEqual(rt2_times_10, 828 [[[10, 20], [30, 40, 50]], [[60]], [], [[], [70]]]) 829 self.assertAllEqual(rt1_expanded, 830 [[[1], [2]], [[3], [4], [5]], [[6]], [], [[7]]]) 831 832 #============================================================================= 833 # Session.run 834 #============================================================================= 835 def testSessionRun(self): 836 if context.executing_eagerly(): 837 return 838 839 rt1 = ragged_factory_ops.constant([[1, 2, 3], [4]]) 840 rt2 = ragged_factory_ops.constant([[[], [1, 2]], [[3]]]) 841 with self.test_session() as session: 842 result = session.run({'rt1': rt1, 'rt2': rt2}) 843 self.assertCountEqual(result.keys(), ['rt1', 'rt2']) 844 self.assertEqual(result['rt1'].to_list(), [[1, 2, 3], [4]]) 845 self.assertEqual(result['rt2'].to_list(), [[[], [1, 2]], [[3]]]) 846 847 def testSessionRunFeed(self): 848 if context.executing_eagerly(): 849 return 850 851 rt1 = RaggedTensor.from_row_splits( 852 array_ops.placeholder(dtypes.int32), 853 array_ops.placeholder(dtypes.int64)) 854 rt2 = RaggedTensor.from_nested_row_splits( 855 array_ops.placeholder(dtypes.int32), [ 856 array_ops.placeholder(dtypes.int64), 857 array_ops.placeholder(dtypes.int64) 858 ]) 859 860 rt1_feed_val = ragged_factory_ops.constant_value([[1, 2, 3], [4]]) 861 rt2_feed_val = ragged_factory_ops.constant_value([[[], [1, 2]], [[3]]]) 862 863 with self.test_session() as session: 864 fetches = {'rt1': rt1, 'rt2': rt2} 865 feeds = {rt1: rt1_feed_val, rt2: rt2_feed_val} 866 result = session.run(fetches, feed_dict=feeds) 867 self.assertCountEqual(result.keys(), ['rt1', 'rt2']) 868 self.assertEqual(result['rt1'].to_list(), [[1, 2, 3], [4]]) 869 self.assertEqual(result['rt2'].to_list(), [[[], [1, 2]], [[3]]]) 870 871 def testSessionPartialRunFeed(self): 872 if context.executing_eagerly(): 873 return 874 875 # Placeholder inputs. 876 a = RaggedTensor.from_row_splits( 877 array_ops.placeholder(dtypes.int32, shape=[None], name='a.values'), 878 array_ops.placeholder(dtypes.int64, name='a.row_splits')) 879 b = RaggedTensor.from_row_splits( 880 array_ops.placeholder(dtypes.int32, shape=[None], name='b.values'), 881 array_ops.placeholder(dtypes.int64, name='b.row_splits')) 882 c = array_ops.placeholder(dtypes.int32, shape=[], name='c') 883 884 # Feed values for placeholder inputs. 885 a_val = ragged_factory_ops.constant_value([[1, 2, 3], [4]]) 886 b_val = ragged_factory_ops.constant_value([[5, 4, 3], [2]]) 887 c_val = 3 888 889 # Compute some values. 890 r1 = ragged_math_ops.reduce_sum(a * b, axis=1) 891 r2 = ragged_math_ops.reduce_sum(a + c, axis=1) 892 893 with self.test_session() as session: 894 handle = session.partial_run_setup([r1, r2], [a, b, c]) 895 896 res1 = session.partial_run(handle, r1, feed_dict={a: a_val, b: b_val}) 897 self.assertAllEqual(res1, [22, 8]) 898 899 res2 = session.partial_run(handle, r2, feed_dict={c: c_val}) 900 self.assertAllEqual(res2, [15, 7]) 901 902 # Test case for GitHub issue 24679. 903 def testEagerForLoop(self): 904 if not context.executing_eagerly(): 905 return 906 907 values = [[1., 2.], [3., 4., 5.], [6.]] 908 r = ragged_factory_ops.constant(values) 909 i = 0 910 for elem in r: 911 self.assertAllEqual(elem, values[i]) 912 i += 1 913 914 def testConsumers(self): 915 if context.executing_eagerly(): 916 return 917 918 a = RaggedTensor.from_row_splits( 919 array_ops.placeholder(dtypes.int32, shape=[None], name='a.values'), 920 array_ops.placeholder(dtypes.int64, name='a.row_splits'), 921 validate=False) 922 ragged_math_ops.reduce_sum(a) 923 self.assertLen(a.consumers(), 1) 924 925 @parameterized.parameters([ 926 { 927 'descr': 'from_value_rowids', 928 'factory': RaggedTensor.from_value_rowids, 929 'test': RaggedTensor.value_rowids, 930 'values': { 931 'values': [1, 2, 3, 4, 5, 6], 932 'value_rowids': [0, 0, 1, 1, 2, 2], 933 }, 934 'tensor_field': 'value_rowids', 935 'value_rowids': [0, 1, 2], 936 'nrows': 10 937 }, 938 { 939 'descr': 'from_row_splits', 940 'factory': RaggedTensor.from_row_splits, 941 # row_splits is a property, not a function. 942 'test': (lambda rt: rt.row_splits), 943 'values': { 944 'values': [1, 2, 3, 4, 5, 6], 945 'row_splits': [0, 2, 4, 6], 946 }, 947 'tensor_field': 'row_splits', 948 'row_splits': [0, 1, 2, 3] 949 }, 950 { 951 'descr': 'from_row_lengths', 952 'factory': RaggedTensor.from_row_lengths, 953 'test': RaggedTensor.row_lengths, 954 'values': { 955 'values': [1, 2, 3, 4, 5, 6], 956 'row_lengths': [2, 2, 2], 957 }, 958 'tensor_field': 'row_lengths', 959 'row_lengths': [1, 1, 1], 960 }, 961 # from_row_starts 962 { 963 'descr': 'from_row_starts', 964 'factory': RaggedTensor.from_row_starts, 965 'test': RaggedTensor.row_starts, 966 'values': { 967 'values': [1, 2, 3, 4, 5, 6], 968 'row_starts': [0, 2, 4] 969 }, 970 'tensor_field': 'row_starts', 971 'row_starts': [0, 1, 2] 972 }, 973 # from_row_limits 974 { 975 'descr': 'from_row_limits', 976 'factory': RaggedTensor.from_row_limits, 977 'test': RaggedTensor.row_limits, 978 'values': { 979 'values': [1, 2, 3, 4, 5, 6], 980 'row_limits': [2, 4, 6] 981 }, 982 'tensor_field': 'row_limits', 983 'row_limits': [3] 984 }, 985 # from_uniform_row_length 986 { 987 'descr': 'from_uniform_row_length', 988 'factory': RaggedTensor.from_uniform_row_length, 989 # One cannot extract uniform_row_length or nvals, so we return 990 # nvals//nrows = uniform_row_length, where nvals = 3 991 'test': (lambda rt: 3 // (rt.shape[0])), 992 'values': { 993 'values': [1, 2, 3, 4, 5, 6], 994 'uniform_row_length': 2 995 }, 996 'tensor_field': 'uniform_row_length', 997 'uniform_row_length': 3 998 }, 999 ]) 1000 def testFactoryTypePreference(self, descr, test, factory, values, 1001 tensor_field, **kwargs): 1002 # When input tensors have shape information, some of these errors will be 1003 # detected statically. 1004 def op_cast(k, v): 1005 if k == tensor_field: 1006 return constant_op.constant(v, dtype=dtypes.int32) 1007 else: 1008 return v 1009 1010 value_copy = {k: op_cast(k, v) for k, v in values.items()} 1011 rt = factory(**value_copy) 1012 1013 kw_copy = {k: v for k, v in kwargs.items()} 1014 kw_copy['values'] = rt 1015 rt2 = factory(**kw_copy) 1016 self.assertAllEqual(kwargs[tensor_field], test(rt2)) 1017 1018 @parameterized.parameters([ 1019 # from_value_rowids 1020 { 1021 'descr': 'bad rank for value_rowids', 1022 'factory': RaggedTensor.from_value_rowids, 1023 'values': [[1, 2], [3, 4]], 1024 'value_rowids': [[1, 2], [3, 4]], 1025 'nrows': 10 1026 }, 1027 { 1028 'descr': 'bad rank for nrows', 1029 'factory': RaggedTensor.from_value_rowids, 1030 'values': [1, 2, 3, 4], 1031 'value_rowids': [1, 2, 3, 4], 1032 'nrows': [10] 1033 }, 1034 { 1035 'descr': 'len(values) != len(value_rowids)', 1036 'factory': RaggedTensor.from_value_rowids, 1037 'values': [1, 2, 3, 4], 1038 'value_rowids': [1, 2, 3, 4, 5], 1039 'nrows': 10 1040 }, 1041 { 1042 'descr': 'negative value_rowid', 1043 'factory': RaggedTensor.from_value_rowids, 1044 'values': [1, 2, 3, 4], 1045 'value_rowids': [-5, 2, 3, 4], 1046 'nrows': 10 1047 }, 1048 { 1049 'descr': 'non-monotonic-increasing value_rowid', 1050 'factory': RaggedTensor.from_value_rowids, 1051 'values': [1, 2, 3, 4], 1052 'value_rowids': [4, 3, 2, 1], 1053 'nrows': 10 1054 }, 1055 { 1056 'descr': 'value_rowid > nrows', 1057 'factory': RaggedTensor.from_value_rowids, 1058 'values': [1, 2, 3, 4], 1059 'value_rowids': [1, 2, 3, 4], 1060 'nrows': 2 1061 }, 1062 { 1063 'descr': 'bad rank for values', 1064 'factory': RaggedTensor.from_value_rowids, 1065 'values': 10, 1066 'value_rowids': [1, 2, 3, 4], 1067 'nrows': 10 1068 }, 1069 1070 # from_row_splits 1071 { 1072 'descr': 'bad rank for row_splits', 1073 'factory': RaggedTensor.from_row_splits, 1074 'values': [[1, 2], [3, 4]], 1075 'row_splits': [[1, 2], [3, 4]] 1076 }, 1077 { 1078 'descr': 'row_splits[0] != 0', 1079 'factory': RaggedTensor.from_row_splits, 1080 'values': [1, 2, 3, 4], 1081 'row_splits': [2, 3, 4] 1082 }, 1083 { 1084 'descr': 'non-monotonic-increasing row_splits', 1085 'factory': RaggedTensor.from_row_splits, 1086 'values': [1, 2, 3, 4], 1087 'row_splits': [0, 3, 2, 4] 1088 }, 1089 { 1090 'descr': 'row_splits[0] != nvals', 1091 'factory': RaggedTensor.from_row_splits, 1092 'values': [1, 2, 3, 4], 1093 'row_splits': [0, 2, 3, 5] 1094 }, 1095 { 1096 'descr': 'bad rank for values', 1097 'factory': RaggedTensor.from_row_splits, 1098 'values': 10, 1099 'row_splits': [0, 1] 1100 }, 1101 1102 # from_row_lengths 1103 { 1104 'descr': 'bad rank for row_lengths', 1105 'factory': RaggedTensor.from_row_lengths, 1106 'values': [1, 2, 3, 4], 1107 'row_lengths': [[1, 2], [1, 0]] 1108 }, 1109 { 1110 'descr': 'negatve row_lengths', 1111 'factory': RaggedTensor.from_row_lengths, 1112 'values': [1, 2, 3, 4], 1113 'row_lengths': [3, -1, 2] 1114 }, 1115 { 1116 'descr': 'sum(row_lengths) != nvals', 1117 'factory': RaggedTensor.from_row_lengths, 1118 'values': [1, 2, 3, 4], 1119 'row_lengths': [2, 4, 2, 8] 1120 }, 1121 { 1122 'descr': 'bad rank for values', 1123 'factory': RaggedTensor.from_row_lengths, 1124 'values': 10, 1125 'row_lengths': [0, 1] 1126 }, 1127 1128 # from_row_starts 1129 { 1130 'descr': 'bad rank for row_starts', 1131 'factory': RaggedTensor.from_row_starts, 1132 'values': [[1, 2], [3, 4]], 1133 'row_starts': [[1, 2], [3, 4]] 1134 }, 1135 { 1136 'descr': 'row_starts[0] != 0', 1137 'factory': RaggedTensor.from_row_starts, 1138 'values': [1, 2, 3, 4], 1139 'row_starts': [2, 3, 4] 1140 }, 1141 { 1142 'descr': 'non-monotonic-increasing row_starts', 1143 'factory': RaggedTensor.from_row_starts, 1144 'values': [1, 2, 3, 4], 1145 'row_starts': [0, 3, 2, 4] 1146 }, 1147 { 1148 'descr': 'row_starts[0] > nvals', 1149 'factory': RaggedTensor.from_row_starts, 1150 'values': [1, 2, 3, 4], 1151 'row_starts': [0, 2, 3, 5] 1152 }, 1153 { 1154 'descr': 'bad rank for values', 1155 'factory': RaggedTensor.from_row_starts, 1156 'values': 10, 1157 'row_starts': [0, 1] 1158 }, 1159 1160 # from_row_limits 1161 { 1162 'descr': 'bad rank for row_limits', 1163 'factory': RaggedTensor.from_row_limits, 1164 'values': [[1, 2], [3, 4]], 1165 'row_limits': [[1, 2], [3, 4]] 1166 }, 1167 { 1168 'descr': 'row_limits[0] < 0', 1169 'factory': RaggedTensor.from_row_limits, 1170 'values': [1, 2, 3, 4], 1171 'row_limits': [-1, 3, 4] 1172 }, 1173 { 1174 'descr': 'non-monotonic-increasing row_limits', 1175 'factory': RaggedTensor.from_row_limits, 1176 'values': [1, 2, 3, 4], 1177 'row_limits': [0, 3, 2, 4] 1178 }, 1179 { 1180 'descr': 'row_limits[0] != nvals', 1181 'factory': RaggedTensor.from_row_limits, 1182 'values': [1, 2, 3, 4], 1183 'row_limits': [0, 2, 3, 5] 1184 }, 1185 { 1186 'descr': 'bad rank for values', 1187 'factory': RaggedTensor.from_row_limits, 1188 'values': 10, 1189 'row_limits': [0, 1] 1190 }, 1191 1192 # from_uniform_row_length 1193 { 1194 'descr': 'rowlen * nrows != nvals (1)', 1195 'factory': RaggedTensor.from_uniform_row_length, 1196 'values': [1, 2, 3, 4, 5], 1197 'uniform_row_length': 3 1198 }, 1199 { 1200 'descr': 'rowlen * nrows != nvals (2)', 1201 'factory': RaggedTensor.from_uniform_row_length, 1202 'values': [1, 2, 3, 4, 5], 1203 'uniform_row_length': 6 1204 }, 1205 { 1206 'descr': 'rowlen * nrows != nvals (3)', 1207 'factory': RaggedTensor.from_uniform_row_length, 1208 'values': [1, 2, 3, 4, 5, 6], 1209 'uniform_row_length': 3, 1210 'nrows': 3 1211 }, 1212 { 1213 'descr': 'rowlen must be a scalar', 1214 'factory': RaggedTensor.from_uniform_row_length, 1215 'values': [1, 2, 3, 4], 1216 'uniform_row_length': [2] 1217 }, 1218 { 1219 'descr': 'rowlen must be nonnegative', 1220 'factory': RaggedTensor.from_uniform_row_length, 1221 'values': [1, 2, 3, 4], 1222 'uniform_row_length': -1 1223 }, 1224 ]) 1225 def testFactoryValidation(self, descr, factory, **kwargs): 1226 # When input tensors have shape information, some of these errors will be 1227 # detected statically. 1228 with self.assertRaises((errors.InvalidArgumentError, ValueError)): 1229 self.evaluate(factory(**kwargs)) 1230 1231 # Remove shape information (by wrapping tensors in placeholders), and check 1232 # that we detect the errors when the graph is run. 1233 if not context.executing_eagerly(): 1234 1235 def wrap_arg(v): 1236 return array_ops.placeholder_with_default( 1237 constant_op.constant(v, dtype=dtypes.int64), 1238 tensor_shape.TensorShape(None)) 1239 1240 kwargs = dict((k, wrap_arg(v)) for (k, v) in kwargs.items()) 1241 1242 with self.assertRaises(errors.InvalidArgumentError): 1243 self.evaluate(factory(**kwargs)) 1244 1245 #============================================================================= 1246 # RaggedTensor Variant conversion 1247 #============================================================================= 1248 1249 @parameterized.named_parameters( 1250 { 1251 'testcase_name': 'Shape_5_none', 1252 'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]], 1253 'ragged_rank': 1 1254 }, { 1255 'testcase_name': 'Shape_4_none_2', 1256 'ragged_constant': [[[1, 2]], [], [[3, 4]], []], 1257 'ragged_rank': 1 1258 }, { 1259 'testcase_name': 'Shape_1_none_none', 1260 'ragged_constant': [[[1], [2, 3, 4, 5, 6, 7]], [[]]], 1261 'ragged_rank': 2 1262 }) 1263 def testRaggedToVariant(self, ragged_constant, ragged_rank): 1264 rt = ragged_factory_ops.constant(ragged_constant, ragged_rank=ragged_rank) 1265 et = rt._to_variant() 1266 self.assertEqual(et.shape.as_list(), []) 1267 self.assertEqual(et.dtype, dtypes.variant) 1268 1269 @parameterized.parameters( 1270 { 1271 'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]], 1272 'ragged_rank': 1, 1273 'num_batched_elems': 5 1274 }, { 1275 'ragged_constant': [[[1, 2]], [], [[3, 4]], []], 1276 'ragged_rank': 1, 1277 'num_batched_elems': 4 1278 }, { 1279 'ragged_constant': [[[1], [2, 3, 4, 5, 6, 7]], [[]]], 1280 'ragged_rank': 2, 1281 'num_batched_elems': 2 1282 }) 1283 def testRaggedToBatchedVariant(self, ragged_constant, ragged_rank, 1284 num_batched_elems): 1285 rt = ragged_factory_ops.constant(ragged_constant, ragged_rank=ragged_rank) 1286 et = rt._to_variant(batched_input=True) 1287 self.assertEqual(et.shape.as_list(), [num_batched_elems]) 1288 self.assertEqual(et.dtype, dtypes.variant) 1289 1290 @parameterized.parameters( 1291 # 2D test cases. 1292 { 1293 'ragged_constant': [[]], 1294 'ragged_rank': 1, 1295 }, 1296 { 1297 'ragged_constant': [[1]], 1298 'ragged_rank': 1, 1299 }, 1300 { 1301 'ragged_constant': [[1, 2]], 1302 'ragged_rank': 1, 1303 }, 1304 { 1305 'ragged_constant': [[1], [2], [3]], 1306 'ragged_rank': 1, 1307 }, 1308 { 1309 'ragged_constant': [[1, 2, 3], [4, 5, 6], [7, 8, 9]], 1310 'ragged_rank': 1, 1311 }, 1312 { 1313 'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]], 1314 'ragged_rank': 1, 1315 }, 1316 # 3D test cases. 1317 { 1318 'ragged_constant': [[[]]], 1319 'ragged_rank': 2, 1320 }, 1321 { 1322 'ragged_constant': [[[1]]], 1323 'ragged_rank': 2, 1324 }, 1325 { 1326 'ragged_constant': [[[1, 2]]], 1327 'ragged_rank': 2, 1328 }, 1329 { 1330 'ragged_constant': [[[1, 2], [3, 4]]], 1331 'ragged_rank': 2, 1332 }, 1333 { 1334 'ragged_constant': [[[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]]], 1335 'ragged_rank': 2, 1336 }, 1337 { 1338 'ragged_constant': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]], 1339 'ragged_rank': 2, 1340 }, 1341 { 1342 'ragged_constant': [[[1, 2]], [], [[3, 4]], []], 1343 'ragged_rank': 2, 1344 }, 1345 # 4D test cases. 1346 { 1347 'ragged_constant': [[[[1, 2], [3, 4]]], 1348 [[[0, 0], [0, 0]], [[5, 6], [7, 8]]], []], 1349 'ragged_rank': 3, 1350 }, 1351 # dtype `string`. 1352 { 1353 'ragged_constant': [['a'], ['b'], ['c']], 1354 'ragged_rank': 1, 1355 'dtype': dtypes.string, 1356 }, 1357 { 1358 'ragged_constant': [[['a', 'b'], ['c', 'd']]], 1359 'ragged_rank': 2, 1360 'dtype': dtypes.string, 1361 }, 1362 { 1363 'ragged_constant': [[[['a', 'b'], ['c', 'd']]], 1364 [[['e', 'f'], ['g', 'h']], [['i', 'j'], 1365 ['k', 'l']]], []], 1366 'ragged_rank': 3, 1367 'dtype': dtypes.string, 1368 }) 1369 def testVariantRoundTrip(self, 1370 ragged_constant, 1371 ragged_rank, 1372 dtype=dtypes.int32): 1373 rt = ragged_factory_ops.constant( 1374 ragged_constant, ragged_rank=ragged_rank, dtype=dtype) 1375 et = rt._to_variant() 1376 round_trip_rt = RaggedTensor._from_variant( 1377 et, dtype, output_ragged_rank=ragged_rank) 1378 self.assertAllEqual(rt, round_trip_rt) 1379 1380 def testBatchedVariantRoundTripInputRaggedRankInferred(self): 1381 ragged_rank = 1 1382 rt = ragged_factory_ops.constant( 1383 [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]], 1384 ragged_rank=ragged_rank) 1385 batched_variant = rt._to_variant(batched_input=True) 1386 nested_batched_variant = array_ops.reshape(batched_variant, [5, 2]) 1387 decoded_rt = RaggedTensor._from_variant( 1388 nested_batched_variant, 1389 dtype=dtypes.int32, 1390 output_ragged_rank=ragged_rank + 1) 1391 expected_rt = ragged_factory_ops.constant([[[0], [1]], [[2], [3]], [[4], 1392 [5]], 1393 [[6], [7]], [[8], [9]]]) 1394 self.assertAllEqual(decoded_rt, expected_rt) 1395 1396 def testBatchedVariantRoundTripWithInputRaggedRank(self): 1397 ragged_rank = 1 1398 rt = ragged_factory_ops.constant( 1399 [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]], 1400 ragged_rank=ragged_rank) 1401 batched_variant = rt._to_variant(batched_input=True) 1402 nested_batched_variant = array_ops.reshape(batched_variant, [5, 2]) 1403 decoded_rt = RaggedTensor._from_variant( 1404 nested_batched_variant, 1405 dtype=dtypes.int32, 1406 output_ragged_rank=ragged_rank + 1, 1407 input_ragged_rank=ragged_rank - 1) 1408 expected_rt = ragged_factory_ops.constant([[[0], [1]], [[2], [3]], [[4], 1409 [5]], 1410 [[6], [7]], [[8], [9]]]) 1411 self.assertAllEqual(decoded_rt, expected_rt) 1412 1413 def testUnbatchVariant(self): # b/141789000 1414 rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]]) 1415 batched = rt._to_variant(batched_input=True) 1416 for i in range(4): 1417 row = RaggedTensor._from_variant( 1418 batched[i], dtype=dtypes.int32, output_ragged_rank=0) 1419 self.assertAllEqual(rt[i], row) 1420 1421 def testUnbatchVariantInDataset(self): 1422 rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]]) 1423 ds = dataset_ops.Dataset.from_tensor_slices(rt) 1424 if context.executing_eagerly(): 1425 for i, value in enumerate(ds): 1426 self.assertAllEqual(rt[i], value) 1427 else: 1428 it = dataset_ops.make_one_shot_iterator(ds) 1429 out = it.get_next() 1430 with self.cached_session() as sess: 1431 for i in range(3): 1432 self.assertAllEqual(sess.run(rt[i]), out) 1433 1434 def testFromVariantInvalidParams(self): 1435 rt = ragged_factory_ops.constant([[0], [1], [2], [3]]) 1436 batched_variant = rt._to_variant(batched_input=True) 1437 nested_batched_variant = array_ops.reshape(batched_variant, [2, 2]) 1438 with self.assertRaisesRegex(ValueError, 1439 'output_ragged_rank must be equal to'): 1440 RaggedTensor._from_variant( 1441 nested_batched_variant, 1442 dtype=dtypes.int32, 1443 output_ragged_rank=1, 1444 input_ragged_rank=1) 1445 1446 def _testRaggedVarientGradient(self, func, x, expected_grad): 1447 x = constant_op.constant(x) 1448 if context.executing_eagerly(): 1449 with backprop.GradientTape() as t: 1450 t.watch(x) 1451 y = func(x) 1452 g = t.gradient(y, x) 1453 else: 1454 y = func(x) 1455 g = gradients_impl.gradients(ys=y, xs=x)[0] 1456 self.assertAllClose(g, expected_grad) 1457 1458 def testRaggedVariantGradients(self): 1459 def func(x): 1460 rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) 1461 rt2 = rt1 * [[10], [100], [1000]] 1462 v = rt2._to_variant(batched_input=False) 1463 rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1) 1464 return rt3.flat_values 1465 1466 self._testRaggedVarientGradient( 1467 func, 1468 [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1469 [10., 10., 10., 10., 100., 100., 100., 1000.]) 1470 1471 def testRaggedVariantGradientsBatched(self): 1472 def func(x): 1473 rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) 1474 rt2 = rt1 * [[10], [100], [1000]] 1475 v = rt2._to_variant(batched_input=True) 1476 rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1) 1477 return rt3.flat_values 1478 1479 self._testRaggedVarientGradient( 1480 func, 1481 [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1482 [10., 10., 10., 10., 100., 100., 100., 1000.]) 1483 1484 def testRaggedVariantGradientsBatchedAndSliced(self): 1485 def func(x, i): 1486 rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) 1487 rt2 = rt1 * [[10], [100], [1000]] 1488 v_slice = rt2._to_variant(batched_input=True)[i] 1489 return RaggedTensor._from_variant(v_slice, dtype=rt2.dtype, 1490 output_ragged_rank=0) 1491 1492 self._testRaggedVarientGradient( 1493 functools.partial(func, i=0), 1494 [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1495 [10., 10., 10., 10., 0., 0., 0., 0.]) 1496 self._testRaggedVarientGradient( 1497 functools.partial(func, i=1), 1498 [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1499 [0., 0., 0., 0., 100., 100., 100., 0.]) 1500 self._testRaggedVarientGradient( 1501 functools.partial(func, i=2), 1502 [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1503 [0., 0., 0., 0., 0., 0., 0., 1000.]) 1504 1505 def testRaggedVariantGradientsRaggedRank0(self): 1506 def func(x): 1507 x2 = x * 2 1508 v = gen_ragged_conversion_ops.ragged_tensor_to_variant( 1509 [], x2, batched_input=False) 1510 return RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=0) 1511 1512 self._testRaggedVarientGradient( 1513 func, 1514 [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1515 [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]) 1516 1517 def testRaggedVariantGradientsRaggedRank3(self): 1518 def func(x): 1519 x2 = x * 2 1520 rt1 = RaggedTensor.from_nested_row_splits( 1521 x2, ([0, 0, 3], [0, 2, 2, 3], [0, 4, 7, 8])) 1522 v = rt1._to_variant(batched_input=False) 1523 rt3 = RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=3) 1524 return rt3.flat_values 1525 1526 self._testRaggedVarientGradient( 1527 func, 1528 [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1529 [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]) 1530 1531 def testRaggedVariantGradientsViaMapFn(self): 1532 rt = RaggedTensor.from_row_splits( 1533 values=[3, 1.0, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 7, 8]) 1534 1535 def func(x): 1536 1537 def transform_row(row): 1538 return math_ops.sqrt( 1539 math_ops.reduce_mean(math_ops.square(row * x), keepdims=True)) 1540 1541 return math_ops.reduce_sum(map_fn.map_fn(transform_row, rt)) 1542 1543 self._testRaggedVarientGradient(func, 3.0, 14.653377) 1544 1545 def testRaggedVariantGradientsViaMapFnReduce(self): 1546 def func(x): 1547 rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) 1548 return map_fn.map_fn( 1549 math_ops.reduce_max, rt1, 1550 fn_output_signature=tensor_spec.TensorSpec((), x.dtype)) 1551 1552 self._testRaggedVarientGradient( 1553 func, 1554 [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1555 [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0]) 1556 1557 def testRaggedVariantGradientsErrors(self): 1558 if context.executing_eagerly(): 1559 return 1560 1561 rt = RaggedTensor.from_row_splits([1.0, 2.0], row_splits=[0, 2, 2]) 1562 v1 = rt._to_variant() 1563 v2 = array_ops.stack([array_ops.stack([v1])]) 1564 y = RaggedTensor._from_variant(v2, rt.dtype, output_ragged_rank=3) 1565 1566 with self.assertRaisesRegex( 1567 ValueError, 'Unable to compute gradient: RaggedTensorToVariant ' 1568 'can currently only generate 0D or 1D output.'): 1569 gradients_impl.gradients(ys=y.flat_values, xs=rt.flat_values) 1570 1571 def assertNumpyObjectTensorsRecursivelyEqual(self, a, b, msg): 1572 """Check that two numpy arrays are equal. 1573 1574 For arrays with dtype=object, check values recursively to see if a and b 1575 are equal. (c.f. `np.array_equal`, which checks dtype=object values using 1576 object identity.) 1577 1578 Args: 1579 a: A numpy array. 1580 b: A numpy array. 1581 msg: Message to display if a != b. 1582 """ 1583 if isinstance(a, np.ndarray) and a.dtype == object: 1584 self.assertEqual(a.dtype, b.dtype, msg) 1585 self.assertEqual(a.shape, b.shape, msg) 1586 self.assertLen(a, len(b), msg) 1587 for a_val, b_val in zip(a, b): 1588 self.assertNumpyObjectTensorsRecursivelyEqual(a_val, b_val, msg) 1589 else: 1590 self.assertAllEqual(a, b, msg) 1591 1592 @parameterized.named_parameters([ 1593 ('Shape_2_R', 1594 [[1, 2], [3, 4, 5]], 1595 np.array([int32array([1, 2]), int32array([3, 4, 5])])), 1596 ('Shape_2_2', 1597 [[1, 2], [3, 4]], 1598 np.array([[1, 2], [3, 4]])), 1599 ('Shape_2_R_2', 1600 [[[1, 2], [3, 4]], [[5, 6]]], 1601 np.array([int32array([[1, 2], [3, 4]]), int32array([[5, 6]])])), 1602 ('Shape_3_2_R', 1603 [[[1], []], [[2, 3], [4]], [[], [5, 6, 7]]], 1604 np.array([[int32array([1]), int32array([])], 1605 [int32array([2, 3]), int32array([4])], 1606 [int32array([]), int32array([5, 6, 7])]])), 1607 ('Shape_0_R', 1608 ragged_factory_ops.constant_value([], ragged_rank=1, dtype=np.int32), 1609 np.zeros([0, 0], dtype=np.int32)), 1610 ('Shape_0_R_2', 1611 ragged_factory_ops.constant_value([], ragged_rank=1, 1612 inner_shape=(2,), dtype=np.int32), 1613 np.zeros([0, 0, 2], dtype=np.int32)), 1614 ]) # pyformat: disable 1615 def testRaggedTensorNumpy(self, rt, expected): 1616 if isinstance(rt, list): 1617 rt = ragged_factory_ops.constant(rt, dtype=dtypes.int32) 1618 else: 1619 rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt) 1620 if context.executing_eagerly(): 1621 actual = rt.numpy() 1622 self.assertNumpyObjectTensorsRecursivelyEqual( 1623 expected, actual, 'Expected %r, got %r' % (expected, actual)) 1624 else: 1625 with self.assertRaisesRegex(ValueError, 'only supported in eager mode'): 1626 rt.numpy() 1627 1628 @parameterized.parameters([ 1629 ([[[1, 2], [3, 4, 5]], [[6]]], 2, None), 1630 ([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]), 1631 ([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]), 1632 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None), 1633 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]), 1634 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]), 1635 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]), 1636 ([[[1, 2, 3]]], 1, [1, 1, None]), 1637 ([[[1, 2, 3]]], 1, [1, 1, 3]), 1638 ]) 1639 def testRaggedTensorSetShape(self, rt, rt_ragged_rank, shape): 1640 rt1 = ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank) 1641 rt1._set_shape(shape) 1642 rt1.shape.assert_is_compatible_with(shape) 1643 if shape is not None: 1644 self.assertIsNot(rt1.shape.rank, None) 1645 for a, b in zip(rt1.shape, shape): 1646 if b is not None: 1647 self.assertEqual(a, b) 1648 1649 @parameterized.parameters([ 1650 ([[[1, 2], [3, 4, 5]], [[6]]], 2, None), 1651 ([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]), 1652 ([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]), 1653 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None), 1654 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]), 1655 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]), 1656 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]), 1657 ([[[1, 2, 3]]], 1, [1, 1, None]), 1658 ([[[1, 2, 3]]], 1, [1, 1, 3]), 1659 ]) 1660 def testRaggedTensorSetShapeWithPlaceholders(self, rt, rt_ragged_rank, shape): 1661 rt2 = nest.map_structure( 1662 lambda x: array_ops.placeholder_with_default(x, None), 1663 ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank), 1664 expand_composites=True) 1665 rt2._set_shape(shape) 1666 rt2.shape.assert_is_compatible_with(shape) 1667 if shape is not None: 1668 self.assertIsNot(rt2.shape.rank, None) 1669 for a, b in zip(rt2.shape, shape): 1670 if b is not None: 1671 self.assertEqual(a, b) 1672 1673 def testRaggedTensorSetShapeUniformRowLength(self): 1674 rt = [[[1], [2], [3]], [[4], [5], [6]]] 1675 1676 rt1 = RaggedTensor.from_tensor(rt, ragged_rank=1) 1677 rt1._set_shape([2, 3, 1]) 1678 1679 rt2 = nest.map_structure( 1680 lambda x: array_ops.placeholder_with_default(x, None), 1681 rt1, expand_composites=True) 1682 rt2._set_shape([2, 3, 1]) 1683 1684 def testRaggedTensorSetShapeInconsistentShapeError(self): 1685 rt = RaggedTensor.from_tensor([[[1], [2], [3]], [[4], [5], [6]]], 1686 ragged_rank=1) 1687 self.assertEqual(rt.shape.as_list(), [2, 3, 1]) 1688 with self.assertRaises(ValueError): 1689 rt._set_shape([None, None, 5]) 1690 with self.assertRaisesRegex(ValueError, 'Inconsistent size'): 1691 rt._set_shape([None, 5, None]) 1692 with self.assertRaises(ValueError): 1693 rt._set_shape([5, None, None]) 1694 1695 1696@test_util.run_all_in_graph_and_eager_modes 1697class RaggedTensorSpecTest(test_util.TensorFlowTestCase, 1698 parameterized.TestCase): 1699 1700 def assertAllTensorsEqual(self, list1, list2): 1701 self.assertLen(list1, len(list2)) 1702 for (t1, t2) in zip(list1, list2): 1703 self.assertAllEqual(t1, t2) 1704 1705 def testConstruction(self): 1706 spec1 = RaggedTensorSpec(ragged_rank=1) 1707 self.assertIsNone(spec1._shape.rank) 1708 self.assertEqual(spec1._dtype, dtypes.float32) 1709 self.assertEqual(spec1._row_splits_dtype, dtypes.int64) 1710 self.assertEqual(spec1._ragged_rank, 1) 1711 1712 self.assertIsNone(spec1.shape.rank) 1713 self.assertEqual(spec1.dtype, dtypes.float32) 1714 self.assertEqual(spec1.row_splits_dtype, dtypes.int64) 1715 self.assertEqual(spec1.ragged_rank, 1) 1716 1717 spec2 = RaggedTensorSpec(shape=[None, None, None]) 1718 self.assertEqual(spec2._shape.as_list(), [None, None, None]) 1719 self.assertEqual(spec2._dtype, dtypes.float32) 1720 self.assertEqual(spec2._row_splits_dtype, dtypes.int64) 1721 self.assertEqual(spec2._ragged_rank, 2) 1722 1723 with self.assertRaisesRegex(ValueError, 'Must specify ragged_rank'): 1724 RaggedTensorSpec() 1725 with self.assertRaisesRegex(TypeError, 'ragged_rank must be an int'): 1726 RaggedTensorSpec(ragged_rank=constant_op.constant(1)) 1727 with self.assertRaisesRegex(ValueError, 1728 'ragged_rank must be less than rank'): 1729 RaggedTensorSpec(ragged_rank=2, shape=[None, None]) 1730 1731 def testValueType(self): 1732 spec1 = RaggedTensorSpec(ragged_rank=1) 1733 self.assertEqual(spec1.value_type, RaggedTensor) 1734 spec2 = RaggedTensorSpec(ragged_rank=0) 1735 self.assertEqual(spec2.value_type, ops.Tensor) 1736 1737 @parameterized.parameters([ 1738 (RaggedTensorSpec(ragged_rank=1), 1739 (tensor_shape.TensorShape(None), dtypes.float32, 1, dtypes.int64)), 1740 (RaggedTensorSpec(shape=[5, None, None]), 1741 (tensor_shape.TensorShape([5, None, None]), dtypes.float32, 1742 2, dtypes.int64)), 1743 (RaggedTensorSpec(shape=[5, None, None], dtype=dtypes.int32), 1744 (tensor_shape.TensorShape([5, None, None]), dtypes.int32, 2, 1745 dtypes.int64)), 1746 (RaggedTensorSpec(ragged_rank=1, row_splits_dtype=dtypes.int32), 1747 (tensor_shape.TensorShape(None), dtypes.float32, 1, dtypes.int32)), 1748 ]) # pyformat: disable 1749 def testSerialize(self, rt_spec, expected): 1750 serialization = rt_spec._serialize() 1751 # TensorShape has an unconventional definition of equality, so we can't use 1752 # assertEqual directly here. But repr() is deterministic and lossless for 1753 # the expected values, so we can use that instead. 1754 self.assertEqual(repr(serialization), repr(expected)) 1755 1756 @parameterized.parameters([ 1757 (RaggedTensorSpec(ragged_rank=0, shape=[5, 3]), [ 1758 tensor_spec.TensorSpec([5, 3], dtypes.float32), 1759 ]), 1760 (RaggedTensorSpec(ragged_rank=1), [ 1761 tensor_spec.TensorSpec(None, dtypes.float32), 1762 tensor_spec.TensorSpec([None], dtypes.int64) 1763 ]), 1764 (RaggedTensorSpec(ragged_rank=1, row_splits_dtype=dtypes.int32), [ 1765 tensor_spec.TensorSpec(None, dtypes.float32), 1766 tensor_spec.TensorSpec([None], dtypes.int32), 1767 ]), 1768 (RaggedTensorSpec(ragged_rank=2), [ 1769 tensor_spec.TensorSpec(None, dtypes.float32), 1770 tensor_spec.TensorSpec([None], dtypes.int64), 1771 tensor_spec.TensorSpec([None], dtypes.int64), 1772 ]), 1773 (RaggedTensorSpec(shape=[5, None, None], dtype=dtypes.string), [ 1774 tensor_spec.TensorSpec([None], dtypes.string), 1775 tensor_spec.TensorSpec([6], dtypes.int64), 1776 tensor_spec.TensorSpec([None], dtypes.int64), 1777 ]), 1778 ]) 1779 def testComponentSpecs(self, rt_spec, expected): 1780 self.assertEqual(rt_spec._component_specs, expected) 1781 1782 @parameterized.parameters([ 1783 { 1784 'rt_spec': RaggedTensorSpec(ragged_rank=0), 1785 'rt': [1.0, 2.0, 3.0], 1786 'components': [[1.0, 2.0, 3.0]] 1787 }, 1788 { 1789 'rt_spec': RaggedTensorSpec(ragged_rank=1), 1790 'rt': [[1.0, 2.0], [3.0]], 1791 'components': [[1.0, 2.0, 3.0], [0, 2, 3]] 1792 }, 1793 { 1794 'rt_spec': RaggedTensorSpec(shape=[2, None, None]), 1795 'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]], 1796 'components': [[1.0, 2.0, 3.0, 4.0], [0, 2, 4], [0, 2, 3, 3, 4]] 1797 }, 1798 ]) 1799 def testToFromComponents(self, rt_spec, rt, components): 1800 rt = ragged_factory_ops.constant(rt) 1801 actual_components = rt_spec._to_components(rt) 1802 self.assertAllTensorsEqual(actual_components, components) 1803 rt_reconstructed = rt_spec._from_components(actual_components) 1804 self.assertAllEqual(rt, rt_reconstructed) 1805 1806 @test_util.run_v1_only('RaggedTensorValue is deprecated in v2') 1807 def testFromNumpyComponents(self): 1808 spec1 = RaggedTensorSpec(ragged_rank=1, dtype=dtypes.int32) 1809 rt1 = spec1._from_components([np.array([1, 2, 3]), np.array([0, 2, 3])]) 1810 self.assertIsInstance(rt1, ragged_tensor_value.RaggedTensorValue) 1811 self.assertAllEqual(rt1, [[1, 2], [3]]) 1812 1813 spec2 = RaggedTensorSpec(ragged_rank=2, dtype=dtypes.int32) 1814 rt2 = spec2._from_components( 1815 [np.array([1, 2, 3]), 1816 np.array([0, 2, 3]), 1817 np.array([0, 0, 2, 3])]) 1818 self.assertIsInstance(rt2, ragged_tensor_value.RaggedTensorValue) 1819 self.assertAllEqual(rt2, [[[], [1, 2]], [[3]]]) 1820 1821 spec3 = RaggedTensorSpec(ragged_rank=0, dtype=dtypes.int32) 1822 rt3 = spec3._from_components([np.array([1, 2, 3])]) 1823 self.assertIsInstance(rt3, np.ndarray) 1824 self.assertAllEqual(rt3, [1, 2, 3]) 1825 1826 @parameterized.parameters([ 1827 RaggedTensorSpec(ragged_rank=0, shape=[5, 3]), 1828 RaggedTensorSpec(ragged_rank=1), 1829 RaggedTensorSpec(ragged_rank=1, row_splits_dtype=dtypes.int32), 1830 RaggedTensorSpec(ragged_rank=2, dtype=dtypes.string), 1831 RaggedTensorSpec(shape=[5, None, None]), 1832 ]) 1833 def testFlatTensorSpecs(self, rt_spec): 1834 self.assertEqual(rt_spec._flat_tensor_specs, 1835 [tensor_spec.TensorSpec(None, dtypes.variant)]) 1836 1837 @parameterized.named_parameters([ 1838 { 1839 'testcase_name': 'RaggedRank0', 1840 'rt_spec': RaggedTensorSpec(ragged_rank=0), 1841 'rt': [1.0, 2.0, 3.0], 1842 }, 1843 { 1844 'testcase_name': 'RaggedRank1', 1845 'rt_spec': RaggedTensorSpec(ragged_rank=1), 1846 'rt': [[1.0, 2.0], [3.0]] 1847 }, 1848 { 1849 'testcase_name': 'RaggedRank2', 1850 'rt_spec': RaggedTensorSpec(shape=[2, None, None]), 1851 'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]] 1852 }, 1853 ]) 1854 def testToFromTensorList(self, rt_spec, rt): 1855 rt = ragged_factory_ops.constant(rt) 1856 tensor_list = rt_spec._to_tensor_list(rt) 1857 rt_reconstructed = rt_spec._from_tensor_list(tensor_list) 1858 self.assertAllEqual(rt, rt_reconstructed) 1859 1860 @parameterized.named_parameters([ 1861 # TODO(b/141789000) Test ragged_rank=0 when support is added. 1862 { 1863 'testcase_name': 'RaggedRank1', 1864 'rt_spec': RaggedTensorSpec(ragged_rank=1), 1865 'rt': [[1.0, 2.0], [3.0]] 1866 }, 1867 { 1868 'testcase_name': 'RaggedRank2', 1869 'rt_spec': RaggedTensorSpec(shape=[2, None, None]), 1870 'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]] 1871 }, 1872 ]) 1873 def testToFromBatchedTensorList(self, rt_spec, rt): 1874 rt = ragged_factory_ops.constant(rt) 1875 tensor_list = rt_spec._to_batched_tensor_list(rt) 1876 rt_reconstructed = rt_spec._from_tensor_list(tensor_list) 1877 self.assertAllEqual(rt, rt_reconstructed) 1878 first_row = rt_spec._unbatch()._from_tensor_list( 1879 [t[0] for t in tensor_list]) 1880 self.assertAllEqual(rt[0], first_row) 1881 1882 def testToFromBatchedTensorListPreservesUniformRowLengths(self): 1883 rt = RaggedTensor.from_tensor(array_ops.zeros([3, 4, 5]), 1884 ragged_rank=2) 1885 rt_spec = rt._type_spec 1886 tensor_list = rt_spec._to_batched_tensor_list(rt) 1887 rt_reconstructed = rt_spec._from_tensor_list(tensor_list) 1888 self.assertAllEqual(rt, rt_reconstructed) 1889 self.assertTrue(rt.shape.is_fully_defined()) 1890 self.assertTrue(rt_reconstructed.shape.is_fully_defined()) 1891 self.assertEqual(rt.shape.as_list(), rt_reconstructed.shape.as_list()) 1892 1893 @parameterized.parameters([ 1894 (RaggedTensorSpec([2, None], dtypes.float32, 1), 32, 1895 RaggedTensorSpec([32, 2, None], dtypes.float32, 2)), 1896 (RaggedTensorSpec([4, None], dtypes.float32, 1), None, 1897 RaggedTensorSpec([None, 4, None], dtypes.float32, 2)), 1898 (RaggedTensorSpec([2], dtypes.float32, 1899 -1), 32, RaggedTensorSpec([32, 2], dtypes.float32, 0)), 1900 ]) 1901 def testBatch(self, spec, batch_size, expected): 1902 self.assertEqual(spec._batch(batch_size), expected) 1903 1904 @parameterized.parameters([ 1905 (RaggedTensorSpec([32, None, None], dtypes.float32, 2), 1906 RaggedTensorSpec([None, None], dtypes.float32, 1)), 1907 (RaggedTensorSpec([None, None, None], dtypes.float32, 2), 1908 RaggedTensorSpec([None, None], dtypes.float32, 1)), 1909 (RaggedTensorSpec([32, 2], dtypes.float32, 0), 1910 RaggedTensorSpec([2], dtypes.float32, -1)), 1911 (RaggedTensorSpec([32, None, 4], dtypes.float32, 1, dtypes.int32), 1912 RaggedTensorSpec([None, 4], dtypes.float32, 0, dtypes.int32)), 1913 ]) # pyformat: disable 1914 def testUnbatch(self, spec, expected): 1915 self.assertEqual(spec._unbatch(), expected) 1916 1917 def testIsCompatibleWith(self): 1918 spec1 = RaggedTensorSpec([32, None, None], dtypes.float32, 2) 1919 spec2 = RaggedTensorSpec(None, dtypes.float32, 2) 1920 spec3 = RaggedTensorSpec(None, dtypes.int32, 1) 1921 spec4 = RaggedTensorSpec([None], dtypes.int32, 0) 1922 1923 self.assertTrue(spec1.is_compatible_with(spec2)) 1924 self.assertFalse(spec1.is_compatible_with(spec3)) 1925 self.assertFalse(spec1.is_compatible_with(spec4)) 1926 self.assertFalse(spec2.is_compatible_with(spec3)) 1927 self.assertFalse(spec2.is_compatible_with(spec4)) 1928 self.assertFalse(spec3.is_compatible_with(spec4)) 1929 self.assertTrue(spec4.is_compatible_with(constant_op.constant([1, 2, 3]))) 1930 1931 1932if __name__ == '__main__': 1933 googletest.main() 1934