# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Support for ragged tensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_ragged_math_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


#===============================================================================
# ragged.range
#===============================================================================
# pylint: disable=redefined-builtin
@tf_export('ragged.range')
@dispatch.add_dispatch_support
def range(starts,
          limits=None,
          deltas=1,
          dtype=None,
          name=None,
          row_splits_dtype=dtypes.int64):
  """Returns a `RaggedTensor` containing the specified sequences of numbers.

  Each row of the returned `RaggedTensor` contains a single sequence:

  ```python
  ragged.range(starts, limits, deltas)[i] ==
      tf.range(starts[i], limits[i], deltas[i])
  ```

  If `starts[i] >= limits[i]` and `deltas[i] > 0`, then `output[i]` will be
  an empty list. Similarly, if `starts[i] <= limits[i]` and `deltas[i] < 0`,
  then `output[i]` will be an empty list. This behavior is consistent with
  the Python `range` function, but differs from the `tf.range` op, which
  returns an error for these cases.

  Examples:

  >>> tf.ragged.range([3, 5, 2]).to_list()
  [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]]
  >>> tf.ragged.range([0, 5, 8], [3, 3, 12]).to_list()
  [[0, 1, 2], [], [8, 9, 10, 11]]
  >>> tf.ragged.range([0, 5, 8], [3, 3, 12], 2).to_list()
  [[0, 2], [], [8, 10]]

  The input tensors `starts`, `limits`, and `deltas` may be scalars or
  vectors. The vector inputs must all have the same size. Scalar inputs are
  broadcast to match the size of the vector inputs.
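
  For example (an illustrative case), a scalar `deltas` is broadcast against
  vector `starts` and `limits`:

  >>> tf.ragged.range([0, 5], [4, 11], 2).to_list()
  [[0, 2], [5, 7, 9]]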

  Args:
    starts: Vector or scalar `Tensor`. Specifies the first entry for each
      range if `limits` is not `None`; otherwise, specifies the range limits,
      and the first entries default to `0`.
    limits: Vector or scalar `Tensor`. Specifies the exclusive upper limits
      for each range.
    deltas: Vector or scalar `Tensor`. Specifies the increment for each
      range. Defaults to `1`.
    dtype: The type of the elements of the resulting tensor. If not
      specified, then a value is chosen based on the other args.
    name: A name for the operation.
    row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
      tensor. One of `tf.int32` or `tf.int64`.

  Returns:
    A `RaggedTensor` of type `dtype` with `ragged_rank=1`.
  """
  row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
  if limits is None:
    starts, limits = 0, starts

  with ops.name_scope(name, 'RaggedRange', [starts, limits, deltas]) as name:
    starts = ops.convert_to_tensor(starts, dtype=dtype, name='starts')
    limits = ops.convert_to_tensor(limits, dtype=dtype, name='limits')
    deltas = ops.convert_to_tensor(deltas, dtype=dtype, name='deltas')

    # Infer the dtype if it was not explicitly provided.
    if dtype is None:
      starts, limits, deltas = _infer_matching_dtype(
          [starts, limits, deltas],
          [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])

    result = gen_ragged_math_ops.ragged_range(
        starts, limits, deltas, Tsplits=row_splits_dtype, name=name)
    return ragged_tensor.RaggedTensor.from_row_splits(
        result.rt_dense_values, result.rt_nested_splits, validate=False)


def _infer_matching_dtype(tensors, dtype_hierarchy):
  """Infers a matching dtype for tensors, and casts them to that dtype."""
  assert all(t.dtype in dtype_hierarchy for t in tensors)
  inferred_dtype = max([t.dtype for t in tensors], key=dtype_hierarchy.index)
  return [math_ops.cast(t, inferred_dtype) for t in tensors]


ops.no_gradient('RaggedRange')

#===============================================================================
# ragged_segment_<AGGREGATE>
#===============================================================================

# Docstring template used for the ragged_segment_<AGGREGATE> ops.
_RAGGED_SEGMENT_DOCSTRING = """\
Computes the %(combination)s along segments of a RaggedTensor.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by taking the %(combination)s of all rows of `data`
  whose corresponding `segment_id` is `i`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`. If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    data: A `RaggedTensor` containing the values to combine.
    segment_ids: A `Tensor` or `RaggedTensor`. Must have type `int64` or
      `int32`. `segment_ids.shape` must be a prefix of `data.shape`.
      Must be greater than or equal to zero, and less than `num_segments`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar specifying the number of
      distinct segment ids.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the %(combined)s values. The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.

  Raises:
    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
"""
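

# Worked example for the segment ops below (an illustrative sketch, assuming
# eager TF 2.x): with data = tf.ragged.constant([[1, 2], [3, 4, 5]]) and
# segment_ids = [0, 0], segment_sum(data, segment_ids, num_segments=1)
# combines the two rows elementwise, giving [[4, 6, 5]]. Positions missing
# from a shorter row contribute the op's identity (0 for sum), and the
# output row length is max(2, 3) = 3.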


def _ragged_segment_aggregate(unsorted_segment_op,
                              data,
                              segment_ids,
                              num_segments,
                              separator=None,
                              name=None):
  """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by combining all rows of `data` whose corresponding
  `segment_id` is `i`. The values in each row are combined using
  `unsorted_segment_op`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`. If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in each row. Must have the same signature and basic behavior as
      `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    data: A `RaggedTensor` containing the values to be combined.
    segment_ids: A `Tensor` or `RaggedTensor`. Must have type `int64` or
      `int32`. `segment_ids.shape` must be a prefix of `data.shape`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar.
    separator: An optional string. Defaults to None. The separator to use
      when joining. Only used for string types.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the aggregated values. The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.

  Raises:
    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
  """
  if not (ragged_tensor.is_ragged(data) or
          ragged_tensor.is_ragged(segment_ids)):
    if separator is not None:
      # A separator is only passed for string joins, in which case the op is
      # unsorted_segment_join, which takes the separator as an extra argument.
      return unsorted_segment_op(data, segment_ids, num_segments, separator,
                                 name)
    else:
      return unsorted_segment_op(data, segment_ids, num_segments, name)

  with ops.name_scope(name, 'RaggedSegment',
                      [data, segment_ids, num_segments]) as name:
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        segment_ids, name='segment_ids')
    data, segment_ids = ragged_tensor.match_row_splits_dtypes(
        data, segment_ids)
    if segment_ids.dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError('segment_ids must have dtype int32 or int64.')

    if ragged_tensor.is_ragged(segment_ids):
      if not ragged_tensor.is_ragged(data):
        raise ValueError('segment_ids.shape must be a prefix of data.shape, '
                         'but segment_ids is ragged and data is not.')
      check_splits = check_ops.assert_equal(
          segment_ids.row_splits,
          data.row_splits,
          message='segment_ids.shape must be a prefix of data.shape')
      with ops.control_dependencies([check_splits]):
        return _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                         segment_ids.values, num_segments,
                                         separator)

    # Find the length of each row in data. (shape=[data_nrows])
    data_row_lengths = data.row_splits[1:] - data.row_splits[:-1]

    # Find the length that each output row will have. The length of the row
    # corresponding to segment `id` is `max(data_row_lengths[i])` where
    # `segment_ids[i] == id`. (shape=[output_nrows])
    output_row_lengths = math_ops.maximum(
        math_ops.unsorted_segment_max(data_row_lengths, segment_ids,
                                      num_segments), 0)

    # Build the splits tensor for the output RaggedTensor.
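    # E.g. (illustrative): if output_row_lengths = [3, 0, 2], then
    # output_splits = [0, 3, 3, 5].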
    output_splits = array_ops.concat([
        array_ops.zeros([1], output_row_lengths.dtype),
        math_ops.cumsum(output_row_lengths)
    ],
                                     axis=0)

    # For each row in `data`, find the start & limit position where that row's
    # values will be aggregated in output.values.
    data_row_to_out_row_start = array_ops.gather(output_splits, segment_ids)
    data_row_to_out_row_limit = data_row_to_out_row_start + data_row_lengths

    # For each value in `data.values`, find the position where it will be
    # aggregated in `output.values`; i.e., map each input value index to its
    # target output value index.
    data_val_to_out_val_index = range(data_row_to_out_row_start,
                                      data_row_to_out_row_limit).values

    # Recursively aggregate the values.
    output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                              data_val_to_out_val_index,
                                              output_splits[-1], separator)
    return ragged_tensor.RaggedTensor.from_row_splits(
        output_values, output_splits, validate=False)


def segment_sum(data, segment_ids, num_segments, name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_sum,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentSum'))


def segment_prod(data, segment_ids, num_segments, name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_prod,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentProd'))


def segment_min(data, segment_ids, num_segments, name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_min,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentMin'))


def segment_max(data, segment_ids, num_segments, name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(
      math_ops.unsorted_segment_max,
      data=data,
      segment_ids=segment_ids,
      num_segments=num_segments,
      name=(name or 'RaggedSegmentMax'))


def segment_mean(data, segment_ids, num_segments, name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentMean',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.ones_like(data.flat_values),
        data.nested_row_splits,
        validate=False)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values / count.flat_values)
    else:
      return total / count


def segment_sqrt_n(data, segment_ids, num_segments, name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentSqrtN',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.ones_like(data.flat_values),
        data.nested_row_splits,
        validate=False)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values /
                                    math_ops.sqrt(count.flat_values))
    else:
      return total / math_ops.sqrt(count)
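

# Worked example for segment_mean (an illustrative sketch): with
# data = tf.ragged.constant([[1., 3.], [2.]]) and segment_ids = [0, 0],
# total is [[3., 3.]] and count is [[2., 1.]], so the mean is [[1.5, 3.]].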


def _set_ragged_segment_docstring(func, combination, combined):
  func.__doc__ = _RAGGED_SEGMENT_DOCSTRING % dict(
      combination=combination, combined=combined)


_set_ragged_segment_docstring(segment_sum, 'sum', 'summed')
_set_ragged_segment_docstring(segment_prod, 'product', 'multiplied')
_set_ragged_segment_docstring(segment_min, 'minimum', 'minimized')
_set_ragged_segment_docstring(segment_max, 'maximum', 'maximized')
_set_ragged_segment_docstring(segment_mean, 'mean', 'averaged')
_set_ragged_segment_docstring(segment_sqrt_n, 'sum divided by sqrt(N)',
                              'summed')

#===============================================================================
# ragged_reduce_<AGGREGATE>
#===============================================================================

# Docstring template used for the ragged_reduce_<AGGREGATE> ops.
_RAGGED_REDUCE_DOCSTRING = """\
Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.

  Reduces `input_tensor` along the dimensions given in `axis` by taking the
  %(combination)s of values. If a reduced dimension has no elements for
  some index, then the value for that index will be %(default)s.

  The rank of the tensor is reduced by `1` for each entry in `axis`. If
  `axis` is not specified, then all dimensions are reduced, and a scalar
  value is returned.

  Args:
    input_tensor: A `RaggedTensor` containing the values to be %(combined)s.
    axis: The dimensions to reduce. May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
      a given set of axes), or a `Tensor` with a constant value. Must be in
      the range `[0, input_tensor.rank)`.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the %(combined)s values. The returned tensor
    has the same dtype as `data`, and its shape is given by removing the
    dimensions specified in `axis` from `input_tensor.shape`. The
    `ragged_rank` of the returned tensor is given by subtracting any ragged
    dimensions specified in `axis` from `input_tensor.ragged_rank`.

  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant.

  #### Example:
    %(example)s
"""
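# Each reduce_<op> docstring below is produced by %-formatting the template
# above (see _set_ragged_reduce_docstring further down); e.g., reduce_sum
# substitutes combination='sum', combined='summed', default='0', and the
# example block that follows.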
_RAGGED_REDUCE_SUM_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_sum(rt, axis=0).numpy()  # = [3+1+9+2, 1+5+6, 4]
    array([15, 12, 4], dtype=int32)
    >>> tf.reduce_sum(rt, axis=1).numpy()  # = [3+1+4, 1+5, 9, 2+6]
    array([8, 6, 9, 8], dtype=int32)
"""
_RAGGED_REDUCE_PROD_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_prod(rt, axis=0).numpy()  # = [3*1*9*2, 1*5*6, 4]
    array([54, 30, 4], dtype=int32)
    >>> tf.reduce_prod(rt, axis=1).numpy()  # = [3*1*4, 1*5, 9, 2*6]
    array([12, 5, 9, 12], dtype=int32)
"""
_RAGGED_REDUCE_MIN_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_min(rt, axis=0).numpy()
    array([1, 1, 4], dtype=int32)
    >>> tf.reduce_min(rt, axis=1).numpy()
    array([1, 1, 9, 2], dtype=int32)
"""
_RAGGED_REDUCE_MAX_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_max(rt, axis=0).numpy()
    array([9, 6, 4], dtype=int32)
    >>> tf.reduce_max(rt, axis=1).numpy()
    array([4, 5, 9, 6], dtype=int32)
"""
_RAGGED_REDUCE_MEAN_EXAMPLE = """
    >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> tf.reduce_mean(rt, axis=0).numpy()
    array([3.75, 4. , 4. ])
    >>> tf.reduce_mean(rt, axis=1).numpy()
    array([2.66666667, 3. , 9. , 4. ])
"""
_RAGGED_REDUCE_VARIANCE_EXAMPLE = """
    >>> rt = tf.ragged.constant([[1, 1, 4], [2, 1], [3], [4, 1]],
    ...                         dtype=tf.float64)
    >>> tf.math.reduce_variance(rt, axis=0).numpy()
    array([1.25, 0., 0.])
    >>> tf.math.reduce_variance(rt, axis=1).numpy()
    array([2., 0.25, 0., 2.25])
"""
_RAGGED_REDUCE_STD_EXAMPLE = """
    >>> rt = tf.ragged.constant([[1, 0], [2, 1], [3], [4, 1]],
    ...                         dtype=tf.float64)
    >>> tf.math.reduce_std(rt, axis=0).numpy()
    array([1.11803399, 0.47140452])
    >>> tf.math.reduce_std(rt, axis=1).numpy()
    array([0.5, 0.5, 0., 1.5])
"""
_RAGGED_REDUCE_ALL_EXAMPLE = """
    >>> rt = tf.ragged.constant([[True, True],
    ...                          [True, True, False, True],
    ...                          [False, True]])
    >>> tf.reduce_all(rt, axis=0).numpy()
    array([False, True, False, True])
    >>> tf.reduce_all(rt, axis=1).numpy()
    array([ True, False, False])
"""
_RAGGED_REDUCE_ANY_EXAMPLE = """
    >>> rt = tf.ragged.constant([[True, True],
    ...                          [True, True, False, True],
    ...                          [False, True]])
    >>> tf.reduce_any(rt, axis=0).numpy()
    array([ True, True, False, True])
    >>> tf.reduce_any(rt, axis=1).numpy()
    array([ True, True, True])
"""


def ragged_reduce_aggregate(reduce_op,
                            unsorted_segment_op,
                            rt_input,
                            axis,
                            keepdims,
                            separator=None,
                            name=None):
  """Aggregates across axes of a RaggedTensor using the given `Tensor` ops.

  Reduces `rt_input` along the dimensions given in `axis`. The rank of the
  tensor is reduced by 1 for each entry in `axis`. If `axis` is not
  specified, then all dimensions are reduced, and a scalar value is returned.

  This op assumes that `reduce_op` and `unsorted_segment_op` are associative;
  if not, then reducing multiple axes will return incorrect results. (In
  particular, reducing multiple axes is currently implemented by reducing the
  axes one at a time.)
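
  For example (an illustrative case): with `axis=[0, 1]`, the inner axis is
  reduced first, so the result equals reducing `axis=1` and then reducing
  `axis=0` of the intermediate result.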

  Args:
    reduce_op: The tensorflow `op` that should be used to reduce values in
      uniform dimensions. Must have the same signature and basic behavior as
      `reduce_sum`, `reduce_max`, etc.
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in ragged dimensions. Must have the same signature and basic
      behavior as `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    rt_input: A `Tensor` or `RaggedTensor` containing the values to be reduced.
    axis: The axis or axes to reduce. May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
      a given set of axes), or a `Tensor` with a constant value. Must be in
      the range `[0, rt_input.rank)`.
    keepdims: If true, retains reduced dimensions with length 1.
    separator: An optional string. Defaults to None. The separator to use
      when joining. The separator must not be set for non-string data types
      (i.e., if separator is not None, then string ops are used).
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the reduced values. The returned tensor
    has the same dtype as `data`, and its shape is given by removing the
    dimensions specified in `axis` from `rt_input.shape`. The `ragged_rank`
    of the returned tensor is given by subtracting any ragged dimensions
    specified in `axis` from `rt_input.ragged_rank`.

  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant.
  """
  if not ragged_tensor.is_ragged(rt_input):
    if separator is None:
      return reduce_op(rt_input, axis, keepdims=keepdims, name=name)
    else:
      # When separator is not None, we infer that the dtype is string and
      # that reduce_join will be called.
      return reduce_op(
          rt_input, axis, keepdims=keepdims, name=name, separator=separator)

  if isinstance(axis, ops.Tensor):
    axis = tensor_util.constant_value(axis)
    if axis is None:
      raise ValueError('axis must be known at graph construction time.')
    if isinstance(axis, np.ndarray):
      axis = axis.tolist()

  # When reducing all axes, just ignore splits & reduce the inner values.
  if axis is None:
    result = reduce_op(rt_input.flat_values, None, keepdims=keepdims,
                       name=name)
    if keepdims:
      # Expand the result to the input number of dimensions.
      for _ in rt_input.shape[1:]:
        result = array_ops.expand_dims(result, axis=0)
    return result

  with ops.name_scope(name, 'RaggedReduce', [rt_input, axis]):
    if isinstance(axis, (tuple, list)):
      if not axis:
        return rt_input
      elif len(axis) == 1:
        axis = axis[0]
      else:
        # When reducing multiple axes, we reduce one at a time (see below),
        # so negative axes must be converted to positive first: sorting a
        # mix of negative and positive axes would give the wrong reduction
        # order. See GitHub issue 27497.
        axis = [
            array_ops.get_positive_axis(a, rt_input.shape.ndims,
                                        'axis[%s]' % i, 'rank(input_tensor)')
            for i, a in enumerate(axis)
        ]
        # When reducing multiple axes, just reduce one at a time. This is less
        # efficient, and only works for associative ops. (In particular, it
        # does not work for reduce_mean.) However, reducing multiple axes at
        # once will probably require a nontrivial c++ op.
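        # E.g. (illustrative): axis=[0, 2] on a rank-3 input reduces axis 2
        # first, then reduces axis 0 of the intermediate result.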
        axis = sorted(axis)
        inner_reduced = ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                                rt_input, axis[-1], keepdims,
                                                separator)
        return ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                       inner_reduced, axis[:-1], keepdims,
                                       separator)

    rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        rt_input, name='rt_input')

    axis = array_ops.get_positive_axis(
        axis, rt_input.shape.ndims, ndims_name='rank(input_tensor)')

    if axis == 0:
      # out[i_1, i_2, ..., i_N] = sum_{j} rt_input[j, i_1, i_2, ..., i_N]
      row_lengths = rt_input.row_splits[1:] - rt_input.row_splits[:-1]
      num_segments = math_ops.maximum(math_ops.reduce_max(row_lengths), 0)
      segment_ids = range(row_lengths).values
      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                         segment_ids, num_segments, separator)
      if keepdims:
        result = array_ops.expand_dims(result, axis=0)
      return result
    elif axis == 1:
      # out[i_0, i_1, i_2, ..., i_N] = sum_{j} rt_input[i_0, j, i_2, ..., i_N]
      num_segments = array_ops.shape(rt_input.row_splits)[0] - 1
      segment_ids = segment_id_ops.row_splits_to_segment_ids(
          rt_input.row_splits)
      result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                         segment_ids, num_segments, separator)
      if keepdims:
        result = array_ops.expand_dims(result, axis=1)
      return result
    else:
      # out[i_0, ..., i_[axis-1], i_[axis+1], ..., i_N] =
      #     sum_{j} rt_input[i_0, ..., i_[axis-1], j, i_[axis+1], ..., i_N]
      return rt_input.with_values(
          ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                  rt_input.values, axis - 1, keepdims,
                                  separator))


def reduce_sum(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_sum,
      unsorted_segment_op=math_ops.unsorted_segment_sum,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceSum'))


def reduce_prod(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_prod,
      unsorted_segment_op=math_ops.unsorted_segment_prod,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceProd'))


def reduce_min(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_min,
      unsorted_segment_op=math_ops.unsorted_segment_min,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceMin'))


def reduce_max(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return ragged_reduce_aggregate(
      reduce_op=math_ops.reduce_max,
      unsorted_segment_op=math_ops.unsorted_segment_max,
      rt_input=input_tensor,
      axis=axis,
      keepdims=keepdims,
      name=(name or 'RaggedReduceMax'))
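

# Worked example for reduce_mean (an illustrative sketch): for
# rt = tf.ragged.constant([[3, 1, 4], [1, 5]]) and axis=0, the elementwise
# sum is [4, 6, 4] and the elementwise count is [2, 2, 1], giving
# [2.0, 3.0, 4.0].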


def reduce_mean(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
    total = reduce_sum(input_tensor, axis, keepdims)
    if ragged_tensor.is_ragged(input_tensor):
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(input_tensor.flat_values),
          input_tensor.nested_row_splits,
          validate=False)
    else:
      ones = array_ops.ones_like(input_tensor)
    count = reduce_sum(ones, axis, keepdims)
    if ragged_tensor.is_ragged(total):
      return ragged_tensor.RaggedTensor.from_nested_row_splits(
          total.flat_values / count.flat_values,
          total.nested_row_splits,
          validate=False)
    else:
      return total / count


def reduce_variance(input_tensor, axis=None, keepdims=False, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceVariance', [input_tensor, axis]):
    square_of_input = math_ops.square(input_tensor)
    mean_of_square = reduce_mean(square_of_input, axis=axis, keepdims=keepdims)
    mean = reduce_mean(input_tensor, axis=axis, keepdims=keepdims)
    square_of_mean = math_ops.square(mean)
    return mean_of_square - square_of_mean


def reduce_std(input_tensor, axis=None, keepdims=False, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceStd', [input_tensor, axis]):
    variance = reduce_variance(input_tensor, axis=axis, keepdims=keepdims)
    return math_ops.sqrt(variance)
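

# E.g. (illustrative): reduce_variance on [[1., 5.]] with axis=1 computes
# mean_of_square = 13. and square_of_mean = 9., so the variance is [4.]
# (using the identity Var[x] = E[x^2] - E[x]^2).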


def _cast(input_tensor, dtype):
  return ragged_functional_ops.map_flat_values(math_ops.cast, input_tensor,
                                               dtype)


def reduce_all(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]):
    return _cast(
        reduce_prod(_cast(input_tensor, dtypes.int32), axis, keepdims),
        dtypes.bool)


def reduce_any(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]):
    return _cast(
        reduce_sum(_cast(input_tensor, dtypes.int32), axis, keepdims),
        dtypes.bool)


def _set_ragged_reduce_docstring(func, combination, combined, default,
                                 example):
  func.__doc__ = _RAGGED_REDUCE_DOCSTRING % dict(
      combination=combination,
      combined=combined,
      default=default,
      example=example)


_set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0',
                             _RAGGED_REDUCE_SUM_EXAMPLE)
_set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1',
                             _RAGGED_REDUCE_PROD_EXAMPLE)
_set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized',
                             '`input_tensor.dtype.min`',
                             _RAGGED_REDUCE_MIN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized',
                             '`input_tensor.dtype.max`',
                             _RAGGED_REDUCE_MAX_EXAMPLE)
_set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN',
                             _RAGGED_REDUCE_MEAN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_variance, 'variance', 'averaged', 'NaN',
                             _RAGGED_REDUCE_VARIANCE_EXAMPLE)
_set_ragged_reduce_docstring(reduce_std, 'std', 'averaged', 'NaN',
                             _RAGGED_REDUCE_STD_EXAMPLE)
_set_ragged_reduce_docstring(reduce_all, 'logical and', 'and-ed', 'True',
                             _RAGGED_REDUCE_ALL_EXAMPLE)
_set_ragged_reduce_docstring(reduce_any, 'logical or', 'or-ed', 'False',
                             _RAGGED_REDUCE_ANY_EXAMPLE)


#===============================================================================
# ragged.matmul
#===============================================================================
def matmul(a,
           b,
           transpose_a=False,
           transpose_b=False,
           adjoint_a=False,
           adjoint_b=False,
           a_is_sparse=False,
           b_is_sparse=False,
           output_type=None,
           name=None):
  """Multiplies matrix `a` by matrix `b`.

  If all transpose or adjoint attributes are `False` then:

  ```
  output[..., i, j] = sum_k (a[..., i, k] * b[..., k, j]), for all indices i, j.
  ```

  The inputs `a` and `b` must have `rank >= 2`, where the outermost `rank - 2`
  dimensions are batch dimensions. The inputs must have the same dtype. See
  `tf.matmul` for more information.

  Args:
    a: `tf.Tensor` or `RaggedTensor` with `rank > 1`.
    b: `tf.Tensor` or `RaggedTensor` with same type and rank as `a`.
    transpose_a: If `True`, `a` is transposed before multiplication.
    transpose_b: If `True`, `b` is transposed before multiplication.
    adjoint_a: If `True`, `a` is conjugated & transposed before multiplication.
    adjoint_b: If `True`, `b` is conjugated & transposed before multiplication.
    a_is_sparse: If `True`, optimize assuming `a` is mostly zero.
    b_is_sparse: If `True`, optimize assuming `b` is mostly zero.
    output_type: The output datatype (optional).
    name: Name for the operation (optional).

  Returns:
    A `Tensor` or `RaggedTensor` with the same rank and shape as `a`, where
    each inner-most matrix is the product of the corresponding matrices in `a`
    and `b`.
  """
  if transpose_a and adjoint_a:
    raise ValueError('Only one of transpose_a and adjoint_a can be True.')
  if transpose_b and adjoint_b:
    raise ValueError('Only one of transpose_b and adjoint_b can be True.')

  kwargs = dict(
      transpose_a=transpose_a,
      transpose_b=transpose_b,
      adjoint_a=adjoint_a,
      adjoint_b=adjoint_b,
      a_is_sparse=a_is_sparse,
      b_is_sparse=b_is_sparse,
      output_type=output_type)

  with ops.name_scope(name, 'RaggedMatMul', [a, b]) as name:
    a = ragged_tensor.convert_to_tensor_or_ragged_tensor(a, name='a')
    b = ragged_tensor.convert_to_tensor_or_ragged_tensor(b, name='b')

    a_is_ragged = isinstance(a, ragged_tensor.RaggedTensor)
    b_is_ragged = isinstance(b, ragged_tensor.RaggedTensor)
    if not (a_is_ragged or b_is_ragged):
      return math_ops.matmul(a, b, **kwargs)

    if a.dtype != b.dtype:
      raise ValueError('`a` and `b` must have the same dtype.')

    # TODO(edloper): Support broadcasting inputs. (Broadcast support is not
    # documented by https://www.tensorflow.org/api_docs/python/tf/linalg/matmul,
    # but it is supported by the op.)

    # Find the rank of the input tensors.
    if a.shape.rank is None:
      if b.shape.rank is None:
        raise ValueError('matmul requires at least one input to have known '
                         'rank if either input is ragged.')
      rank = b.shape.rank
    else:
      if b.shape.rank is not None and a.shape.rank != b.shape.rank:
        raise ValueError('`a` and `b` must have the same rank.')
      rank = a.shape.rank

    # At least one of `a` and `b` is ragged; and ragged tensors always have
    # rank >= 2.
    if rank < 2:
      # This can happen if, e.g., `a` is a 1D dense tensor and `b` is a
      # ragged tensor with unknown rank. Since ragged tensors always have
      # `rank >= 2`, this implies that `a` and `b` have different ranks.
      raise ValueError('`a` and `b` must have the same rank.')

    # Rank > 3: We have multiple batch dimensions. Merge them into a single
    # batch dimension, recursively call `matmul`, and then restore the
    # original batch dimension (using a.row_splits).
    if rank > 3:
      shape_err = 'Batch dimensions of `a` and `b` do not have the same size.'
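      # E.g. (illustrative): rank-4 inputs with shapes [B1, B2, I, J] and
      # [B1, B2, J, K] are each wrapped as ragged tensors over their first
      # dimension, so the recursive call sees rank-3 values, and the result
      # is re-wrapped with a's row_splits.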
      if not a_is_ragged:
        a = ragged_tensor.RaggedTensor.from_tensor(a, ragged_rank=1)
      if not b_is_ragged:
        b = ragged_tensor.RaggedTensor.from_tensor(b, ragged_rank=1)
      with ops.control_dependencies([
          check_ops.assert_equal(a.row_splits, b.row_splits, message=shape_err)
      ]):
        flat_result = matmul(a.values, b.values, **kwargs)
        return a.with_values(flat_result)

    if rank == 2:
      return _matmul_2d(a, b, **kwargs)

    assert rank == 3  # I.e., we have a single batch dimension.

    a_ragged_rank = a.ragged_rank if a_is_ragged else 0
    if a_ragged_rank == 1 and not (b_is_ragged or transpose_a or adjoint_a):
      # If `a.shape=[B, (I), J]` and `b.shape=[B, J, K]`, then we can compute
      # the result with a single dense `matmul`.
      return _matmul_3d_with_batch_dim_folding(a, b, **kwargs)
    else:
      # Otherwise, fall back on using `map_fn`.
      return _matmul_3d_with_map_fn(a, b, **kwargs)


def _matmul_2d(a, b, **kwargs):
  """Multiplies potentially ragged 2D tensors.

  Args:
    a: A 2D Tensor or RaggedTensor with `shape=[I, J]`.
    b: A 2D Tensor or RaggedTensor with `shape=[J, K]`.
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).

  Returns:
    A 2D Tensor with `shape=[I, K]`.
  """
  # Multiplying `a` and `b` is only well-defined if `a` and `b` are
  # actually uniform (and just happened to be stored as ragged tensors).
  # Check that they're uniform, and convert them to tf.Tensor.
  ragged_err = ('The matrices in `a` and `b` may not be '
                'ragged in their innermost dimension.')
  checks = []
  if isinstance(a, ragged_tensor.RaggedTensor):
    original_size = array_ops.size(a.flat_values)
    a = a.to_tensor()
    checks.append(
        check_ops.assert_equal(
            original_size, array_ops.size(a), message=ragged_err))
  if isinstance(b, ragged_tensor.RaggedTensor):
    original_size = array_ops.size(b.flat_values)
    b = b.to_tensor()
    checks.append(
        check_ops.assert_equal(
            original_size, array_ops.size(b), message=ragged_err))
  with ops.control_dependencies(checks):
    return math_ops.matmul(a, b, **kwargs)


def _matmul_3d_with_map_fn(a, b, **kwargs):
  """Multiplies batches of 2D matrices using map_fn.

  `output[n, i, k] = sum_j(a[n, i, j] * b[n, j, k])` (for all `n`, `i`, `k`).

  Requires that each row `a[n, i]` has length equal to `b[n].nrows()` (for
  all `n` and `i`).

  Args:
    a: A 3D Tensor or RaggedTensor with `shape=[B, I, J]`, where dimensions
      `I` and `J` may be ragged.
    b: A 3D Tensor or RaggedTensor with `shape=[B, J, K]`, where dimensions
      `J` and `K` may be ragged.
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).

  Returns:
    A 3D RaggedTensor with `shape=[B, (I), (K)]`.
  """
  if isinstance(b, ragged_tensor.RaggedTensor) and b.ragged_rank == 2:
    output_ragged_rank = 2
  else:
    output_ragged_rank = 1

  def single_batch_matmul(x):
    out = _matmul_2d(x[0], x[1], **kwargs)
    if output_ragged_rank == 2:
      out = ragged_tensor.RaggedTensor.from_tensor(out)
    return out

  fn_out_shape = None  # Figure out proper shape.
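  # E.g. (illustrative): for a with shape [B, (I), J] and b with shape
  # [B, J, K], map_fn applies single_batch_matmul to one ([I, J], [J, K])
  # pair of matrices per batch element.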
  row_splits_dtype = (
      a.row_splits.dtype
      if isinstance(a, ragged_tensor.RaggedTensor) else b.row_splits.dtype)
  output_type = kwargs['output_type']
  if output_type is None:
    output_type = a.dtype
  spec = ragged_tensor.RaggedTensorSpec(
      shape=fn_out_shape,
      dtype=output_type,
      ragged_rank=output_ragged_rank - 1,
      row_splits_dtype=row_splits_dtype)
  result = map_fn.map_fn(
      single_batch_matmul, elems=(a, b), fn_output_signature=spec)

  # map_fn loses shape information; restore it, where possible.
  # pylint: disable=protected-access
  if kwargs.get('transpose_a') or kwargs.get('adjoint_a'):
    result._set_shape(a.shape[:-2] + a.shape[-1:] + [None])
  else:
    result._set_shape(a.shape[:-2] + a.shape[-2:-1] + [None])
  if kwargs.get('transpose_b') or kwargs.get('adjoint_b'):
    result._set_shape(b.shape[:-2] + [None] + b.shape[-2:-1])
  else:
    result._set_shape(b.shape[:-2] + [None] + b.shape[-1:])

  return result


def _matmul_3d_with_batch_dim_folding(a, b, **kwargs):
  """Multiply batches of 2D matrices where only `a.shape[1]` is ragged.

  Args:
    a: A RaggedTensor with `shape=[B, (I), J]`. (ragged_rank must be 1.)
    b: A Tensor with `shape=[B, J, K]`.
    **kwargs: Additional arguments for `tf.matmul` (e.g. transpose_a).
      transpose_a and adjoint_a must not be true.

  Returns:
    A RaggedTensor with `shape=[B, (I), K]`.
  """
  # reshaped_a.shape = [sum(i_1, i_2, ..., i_B), 1, J]
  reshaped_a = array_ops.expand_dims(a.values, 1)
  # reshaped_b.shape = [sum(i_1, i_2, ..., i_B), J, K]
  reshaped_b = array_ops.repeat(b, a.row_lengths(), axis=0)
  # flat_result.shape = [sum(i_1, i_2, ..., i_B), 1, K]
  flat_result = math_ops.matmul(reshaped_a, reshaped_b, **kwargs)
  # result.shape = [B, (I), K]
  return a.with_values(array_ops.squeeze(flat_result, axis=1))


#===============================================================================
# ragged.softmax
#===============================================================================
def softmax(logits, axis=None, name=None):
  """Computes softmax activations.

  Used for multi-class predictions. The sum of all outputs generated by
  softmax is 1.

  This function performs the equivalent of

      softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)

  Example usage:

  >>> softmax = tf.nn.softmax([-1, 0., 1.])
  >>> softmax
  <tf.Tensor: shape=(3,), dtype=float32,
  numpy=array([0.09003057, 0.24472848, 0.66524094], dtype=float32)>
  >>> sum(softmax)
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

  Args:
    logits: A non-empty `Tensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    axis: The dimension softmax would be performed on. The default is -1,
      which indicates the last dimension.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type and shape as `logits`.

  Raises:
    InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
      dimension of `logits`.
  """
  if axis is None:
    axis = -1

  with ops.name_scope(name, 'RaggedSoftmax', [logits]) as name:
    logits_exp = math_ops.exp(logits)
    denominator = reduce_sum(logits_exp, axis=axis, keepdims=True)
    return math_ops.divide(logits_exp, denominator)
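

# Usage sketch for the ragged softmax above (illustrative; assumes TF 2.x
# eager execution, with ragged dispatch routing `tf.nn.softmax` to this
# implementation for RaggedTensor inputs):
#
#   rt = tf.ragged.constant([[0., 0.], [0., 0., 0.]])
#   tf.nn.softmax(rt, axis=-1).to_list()
#   # => [[0.5, 0.5], [0.333..., 0.333..., 0.333...]]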