# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RaggedKerasTensor tests."""

from absl.testing import parameterized
import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers
from tensorflow.python.keras.engine import training
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.util import nest


class RaggedKerasTensorTest(keras_parameterized.TestCase):
  """Tests operations applied to ragged Keras `Input`s inside functional models.

  Each test builds a model from a ragged `layers.Input`, runs it on a concrete
  RaggedTensor, and compares against the same operation applied eagerly.
  """

  def _assert_flat_components_equal(self, actual, expected):
    """Asserts two (possibly composite) structures have equal components.

    Both structures are flattened with `expand_composites=True`, because
    assertAllEqual otherwise wouldn't work for SparseTensor outputs. The
    number of flattened components is checked first so that a missing or
    extra component fails the test rather than being silently dropped by
    `zip`, which stops at the shorter sequence.

    Args:
      actual: Value produced by the model under test.
      expected: Reference value computed eagerly.
    """
    actual_flat = nest.flatten(actual, expand_composites=True)
    expected_flat = nest.flatten(expected, expand_composites=True)
    self.assertEqual(len(actual_flat), len(expected_flat))
    for actual_component, expected_component in zip(actual_flat,
                                                    expected_flat):
      self.assertAllEqual(actual_component, expected_component)

  # In every case below the expected `ragged_rank` equals the position of the
  # last `None` dimension in the full (batch-prefixed) shape.
  @parameterized.parameters(
      {'batch_size': None, 'shape': (None, 5), 'ragged_rank': 1},
      {'batch_size': None, 'shape': (None, 3, 5), 'ragged_rank': 1},
      {'batch_size': None, 'shape': (5, None), 'ragged_rank': 2},
      {'batch_size': None, 'shape': (3, 5, None), 'ragged_rank': 3},
      {'batch_size': None, 'shape': (None, 3, 5, None), 'ragged_rank': 4},
      {'batch_size': None, 'shape': (2, 3, None, 4, 5, None), 'ragged_rank': 6},
      {'batch_size': 8, 'shape': (None, 5), 'ragged_rank': 1},
      {'batch_size': 9, 'shape': (None, 3, 5), 'ragged_rank': 1},
      {'batch_size': 1, 'shape': (5, None), 'ragged_rank': 2},
      {'batch_size': 4, 'shape': (3, 5, None), 'ragged_rank': 3},
      {'batch_size': 7, 'shape': (None, 3, 5, None), 'ragged_rank': 4},
      {'batch_size': 12, 'shape': (2, 3, None, 4, 5, None), 'ragged_rank': 6},
  )
  def test_to_placeholder(self, shape, batch_size, ragged_rank):
    inp = layers.Input(shape=shape, batch_size=batch_size, ragged=True)
    self.assertEqual(inp.ragged_rank, ragged_rank)
    self.assertAllEqual(inp.shape, [batch_size] + list(shape))
    # The placeholder must be created inside a graph context.
    with func_graph.FuncGraph('test').as_default():
      placeholder = inp._to_placeholder()
    self.assertEqual(placeholder.ragged_rank, ragged_rank)
    self.assertAllEqual(placeholder.shape, [batch_size] + list(shape))

  def test_add(self):
    inp = layers.Input(shape=[None], ragged=True)
    out = inp + inp
    model = training.Model(inp, out)

    x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
    self.assertAllEqual(model(x), x + x)

  def test_mul(self):
    inp = layers.Input(shape=[None], ragged=True)
    out = inp * inp
    model = training.Model(inp, out)

    x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
    self.assertAllEqual(model(x), x * x)

  def test_sub(self):
    inp = layers.Input(shape=[None], ragged=True)
    out = inp - inp
    model = training.Model(inp, out)

    x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
    self.assertAllEqual(model(x), x - x)

  def test_div(self):
    inp = layers.Input(shape=[None], ragged=True)
    out = inp / inp
    model = training.Model(inp, out)

    x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
    self.assertAllEqual(model(x), x / x)

  def test_getitem(self):
    # Test slicing / getitem
    inp = layers.Input(shape=(None, 2), ragged=True)
    out = inp[:, :2]
    model = training.Model(inp, out)

    x = ragged_tensor.RaggedTensor.from_row_lengths(
        math_ops.cast(np.random.randn(6, 2), dtype=dtypes.float32), [3, 1, 2])
    expected = x[:, :2]

    self.assertAllEqual(model(x), expected)

    # Test that models w/ slicing are correctly serialized/deserialized
    config = model.get_config()
    model = training.Model.from_config(config)

    self.assertAllEqual(model(x), expected)

  @parameterized.parameters(
      {'property_name': 'values'},
      {'property_name': 'flat_values'},
      {'property_name': 'row_splits'},
      {'property_name': 'nested_row_splits'},
  )
  def test_instance_property(self, property_name):
    inp = layers.Input(shape=[None], ragged=True)
    out = getattr(inp, property_name)
    model = training.Model(inp, out)

    x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
    expected_property = getattr(x, property_name)
    self.assertAllEqual(model(x), expected_property)

    # Test that it works with serialization and deserialization as well
    model_config = model.get_config()
    model2 = training.Model.from_config(model_config)
    self.assertAllEqual(model2(x), expected_property)

  @parameterized.parameters(
      {'name': 'value_rowids'},
      {'name': 'nested_value_rowids'},
      {'name': 'nrows'},
      {'name': 'row_starts'},
      {'name': 'row_limits'},
      {'name': 'row_lengths'},
      {'name': 'nested_row_lengths'},
      {'name': 'bounding_shape'},
      {
          'name': 'with_values',
          'args': [[1, 2, 3, 4, 5, 6]]
      },
      {
          'name': 'with_flat_values',
          'kwargs': {
              'new_values': [1, 2, 3, 4, 5, 6]
          }
      },
      {
          'name': 'with_row_splits_dtype',
          'kwargs': {
              'dtype': dtypes.int32
          }
      },
      {
          'name': 'merge_dims',
          'args': [0],
          'kwargs': {
              'inner_axis': 1
          }
      },
      {'name': 'to_tensor'},
      {'name': 'to_sparse'},
  )
  def test_instance_method(self, name, args=None, kwargs=None):
    """Checks that RaggedTensor method `name` works on a ragged Keras Input.

    Args:
      name: Name of the RaggedTensor method to call.
      args: Optional positional arguments for the method.
      kwargs: Optional keyword arguments for the method.
    """
    args = args or []
    kwargs = kwargs or {}

    inp = layers.Input(shape=[None], ragged=True)
    out = getattr(inp, name)(*args, **kwargs)
    model = training.Model(inp, out)

    x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
    expected = getattr(x, name)(*args, **kwargs)
    # Composite outputs (e.g. SparseTensor from `to_sparse`) are compared
    # component-wise, with a check that no component is missing.
    self._assert_flat_components_equal(model(x), expected)

    # Test that the model can serialize and deserialize as well
    model_config = model.get_config()
    model2 = training.Model.from_config(model_config)
    self._assert_flat_components_equal(model2(x), expected)


class RaggedTensorClassMethodAsLayerTest(keras_parameterized.TestCase):
  """Tests RaggedTensor factory classmethods applied to Keras Inputs.

  Each test builds a functional model whose output is a RaggedTensor built
  from a (non-ragged or sparse) Keras Input. It then checks the model against
  the same factory applied eagerly, both before and after a
  get_config/from_config round trip.
  """

  def _assert_model_output_and_round_trip(self, model, inputs, expected):
    """Asserts `model(inputs)` equals `expected`, also after reconstruction.

    Args:
      model: Functional model under test.
      inputs: Concrete value fed to the model.
      expected: Reference value computed eagerly.
    """
    self.assertAllEqual(model(inputs), expected)

    # Test that the model can serialize and deserialize as well
    model_config = model.get_config()
    model2 = training.Model.from_config(model_config)
    self.assertAllEqual(model2(inputs), expected)

  def test_from_value_rowids(self):
    inp = layers.Input(shape=[None])
    out = ragged_tensor.RaggedTensor.from_value_rowids(
        inp, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
    model = training.Model(inp, out)

    x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6])
    expected = ragged_tensor.RaggedTensor.from_value_rowids(
        x, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
    self._assert_model_output_and_round_trip(model, x, expected)

  def test_from_row_splits(self):
    inp = layers.Input(shape=[None])
    out = ragged_tensor.RaggedTensor.from_row_splits(
        inp, row_splits=[0, 4, 4, 7, 8, 8])
    model = training.Model(inp, out)

    x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6])
    expected = ragged_tensor.RaggedTensor.from_row_splits(
        x, row_splits=[0, 4, 4, 7, 8, 8])
    self._assert_model_output_and_round_trip(model, x, expected)

  def test_from_row_lengths(self):
    inp = layers.Input(shape=[None])
    out = ragged_tensor.RaggedTensor.from_row_lengths(
        inp, row_lengths=[4, 0, 3, 1, 0])
    model = training.Model(inp, out)

    x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6])
    expected = ragged_tensor.RaggedTensor.from_row_lengths(
        x, row_lengths=[4, 0, 3, 1, 0])
    self._assert_model_output_and_round_trip(model, x, expected)

  def test_from_row_starts(self):
    inp = layers.Input(shape=[None])
    out = ragged_tensor.RaggedTensor.from_row_starts(
        inp, row_starts=[0, 4, 4, 7, 8])
    model = training.Model(inp, out)

    x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6])
    expected = ragged_tensor.RaggedTensor.from_row_starts(
        x, row_starts=[0, 4, 4, 7, 8])
    self._assert_model_output_and_round_trip(model, x, expected)

  def test_from_row_limits(self):
    row_limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64)

    inp = layers.Input(shape=[None], dtype=dtypes.string)
    out = ragged_tensor.RaggedTensor.from_row_limits(
        inp, row_limits, validate=False)
    model = training.Model(inp, out)

    x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    expected = ragged_tensor.RaggedTensor.from_row_limits(
        x, row_limits, validate=False)
    self._assert_model_output_and_round_trip(model, x, expected)

  def test_from_uniform_row_length(self):
    inp = layers.Input(shape=[None])
    out = ragged_tensor.RaggedTensor.from_uniform_row_length(inp, 2, 8)
    model = training.Model(inp, out)

    x = constant_op.constant(
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
    expected = ragged_tensor.RaggedTensor.from_uniform_row_length(x, 2, 8)
    self._assert_model_output_and_round_trip(model, x, expected)

  def test_from_nested_value_row_ids(self):
    nested_value_rowids = [
        constant_op.constant([0, 0, 1, 3, 3], dtypes.int64),
        constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    ]
    inp = layers.Input(shape=[None], dtype=dtypes.string)
    out = ragged_tensor.RaggedTensor.from_nested_value_rowids(
        inp, nested_value_rowids)
    model = training.Model(inp, out)

    x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    expected = ragged_tensor.RaggedTensor.from_nested_value_rowids(
        x, nested_value_rowids)
    self._assert_model_output_and_round_trip(model, x, expected)

  def test_from_nested_row_splits(self):
    nested_row_splits = [
        constant_op.constant([0, 2, 3, 3, 5], dtypes.int64),
        constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
    ]
    inp = layers.Input(shape=[None], dtype=dtypes.string)
    out = ragged_tensor.RaggedTensor.from_nested_row_splits(
        inp, nested_row_splits)
    model = training.Model(inp, out)

    x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    expected = ragged_tensor.RaggedTensor.from_nested_row_splits(
        x, nested_row_splits)
    self._assert_model_output_and_round_trip(model, x, expected)

  def test_from_nested_row_lengths(self):
    nested_row_lengths = [
        constant_op.constant([2, 1, 0, 2], dtypes.int64),
        constant_op.constant([2, 0, 3, 1, 1], dtypes.int64)
    ]
    inp = layers.Input(shape=[None], dtype=dtypes.string)
    out = ragged_tensor.RaggedTensor.from_nested_row_lengths(
        inp, nested_row_lengths)
    model = training.Model(inp, out)

    x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    expected = ragged_tensor.RaggedTensor.from_nested_row_lengths(
        x, nested_row_lengths)
    self._assert_model_output_and_round_trip(model, x, expected)

  def test_from_tensor(self):
    inp = layers.Input(shape=[None], ragged=False)
    out = ragged_tensor.RaggedTensor.from_tensor(inp)
    model = training.Model(inp, out)

    x = constant_op.constant([[3., 4.], [1., 2.], [3., 5.]])
    expected = ragged_tensor.RaggedTensor.from_tensor(x)
    self._assert_model_output_and_round_trip(model, x, expected)

  def test_from_sparse(self):
    inp = layers.Input(shape=[None], sparse=True, dtype=dtypes.string)
    out = ragged_tensor.RaggedTensor.from_sparse(inp)
    model = training.Model(inp, out)

    indices = [[0, 0], [1, 0], [1, 1], [2, 0]]
    values = [b'a', b'b', b'c', b'd']
    shape = [4, 5]
    sp_value = sparse_tensor.SparseTensor(indices, values, shape)

    expected = ragged_tensor.RaggedTensor.from_sparse(sp_value)
    self._assert_model_output_and_round_trip(model, sp_value, expected)


if __name__ == '__main__':
  ops.enable_eager_execution()
  tensor_shape.enable_v2_tensorshape()
  test.main()