1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Support for ragged tensors.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import ops 25from tensorflow.python.framework import tensor_util 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import check_ops 28from tensorflow.python.ops import gen_ragged_math_ops 29from tensorflow.python.ops import math_ops 30from tensorflow.python.ops.ragged import ragged_functional_ops 31from tensorflow.python.ops.ragged import ragged_tensor 32from tensorflow.python.ops.ragged import segment_id_ops 33from tensorflow.python.util import dispatch 34from tensorflow.python.util.tf_export import tf_export 35 36 37#=============================================================================== 38# ragged.range 39#=============================================================================== 40# pylint: disable=redefined-builtin 41@tf_export('ragged.range') 42@dispatch.add_dispatch_support 43def range(starts, 44 limits=None, 45 deltas=1, 46 dtype=None, 47 name=None, 48 row_splits_dtype=dtypes.int64): 49 """Returns a `RaggedTensor` containing the specified sequences of numbers. 50 51 Each row of the returned `RaggedTensor` contains a single sequence: 52 53 ```python 54 ragged.range(starts, limits, deltas)[i] == 55 tf.range(starts[i], limits[i], deltas[i]) 56 ``` 57 58 If `start[i] < limits[i] and deltas[i] > 0`, then `output[i]` will be an 59 empty list. Similarly, if `start[i] > limits[i] and deltas[i] < 0`, then 60 `output[i]` will be an empty list. This behavior is consistent with the 61 Python `range` function, but differs from the `tf.range` op, which returns 62 an error for these cases. 63 64 Examples: 65 66 >>> tf.ragged.range([3, 5, 2]).to_list() 67 [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]] 68 >>> tf.ragged.range([0, 5, 8], [3, 3, 12]).to_list() 69 [[0, 1, 2], [], [8, 9, 10, 11]] 70 >>> tf.ragged.range([0, 5, 8], [3, 3, 12], 2).to_list() 71 [[0, 2], [], [8, 10]] 72 73 The input tensors `starts`, `limits`, and `deltas` may be scalars or vectors. 74 The vector inputs must all have the same size. Scalar inputs are broadcast 75 to match the size of the vector inputs. 76 77 Args: 78 starts: Vector or scalar `Tensor`. Specifies the first entry for each range 79 if `limits` is not `None`; otherwise, specifies the range limits, and the 80 first entries default to `0`. 81 limits: Vector or scalar `Tensor`. Specifies the exclusive upper limits for 82 each range. 83 deltas: Vector or scalar `Tensor`. Specifies the increment for each range. 84 Defaults to `1`. 85 dtype: The type of the elements of the resulting tensor. If not specified, 86 then a value is chosen based on the other args. 87 name: A name for the operation. 88 row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits` 89 tensor. One of `tf.int32` or `tf.int64`. 90 91 Returns: 92 A `RaggedTensor` of type `dtype` with `ragged_rank=1`. 93 """ 94 row_splits_dtype = dtypes.as_dtype(row_splits_dtype) 95 if limits is None: 96 starts, limits = 0, starts 97 98 with ops.name_scope(name, 'RaggedRange', [starts, limits, deltas]) as name: 99 starts = ops.convert_to_tensor(starts, dtype=dtype, name='starts') 100 limits = ops.convert_to_tensor(limits, dtype=dtype, name='limits') 101 deltas = ops.convert_to_tensor(deltas, dtype=dtype, name='deltas') 102 103 # infer dtype if not explicitly provided 104 if dtype is None: 105 starts, limits, deltas = _infer_matching_dtype( 106 [starts, limits, deltas], 107 [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]) 108 109 result = gen_ragged_math_ops.ragged_range( 110 starts, limits, deltas, Tsplits=row_splits_dtype, name=name) 111 return ragged_tensor.RaggedTensor.from_row_splits( 112 result.rt_dense_values, result.rt_nested_splits, validate=False) 113 114 115def _infer_matching_dtype(tensors, dtype_hierarchy): 116 """Infers a matching dtype for tensors, and casts them to that dtype.""" 117 assert all(t.dtype in dtype_hierarchy for t in tensors) 118 inferred_dtype = max([t.dtype for t in tensors], key=dtype_hierarchy.index) 119 return [math_ops.cast(t, inferred_dtype) for t in tensors] 120 121 122ops.no_gradient('RaggedRange') 123 124#=============================================================================== 125# ragged_segment_<AGGREGATE> 126#=============================================================================== 127 128# Docstring template used for the raggged_segment_<AGGREGATE> ops. 129_RAGGED_SEGMENT_DOCSTRING = """\ 130Computes the %(combination)s along segments of a RaggedTensor. 131 132 Returns a RaggedTensor `output` with `num_segments` rows, where the row 133 `output[i]` is formed by taking the %(combination)s of all rows of `data` 134 whose corresponding `segment_id` is `i`. 135 136 The length of the row `output[i]` will be the maximum of the lengths of 137 all rows of `data` whose corresponding `segment_id` is `i`. If no `data` 138 rows correspond to a given segment ID, then the output row for that segment 139 ID will be empty. 140 141 Args: 142 data: A `RaggedTensor` containing the values to combine. 143 segment_ids: A `Tensor` or `RaggedTensor`. Must have type `int64` or 144 `int32`. `segment_ids.shape` must be a prefix of `data.shape`. 145 Must be greater than or equal to zero, and less than `num_segments`. 146 `segment_ids` is not required to be sorted. 147 num_segments: An `int32` or `int64` scalar specifying the number of 148 distinct segment ids. 149 name: A name prefix for the returned tensor (optional). 150 Returns: 151 A `RaggedTensor` containing the %(combined)s values. The returned tensor 152 has the same dtype as `data`, and its shape is 153 `[num_segments] + data.shape[segment_ids.rank:]`. 154 Raises: 155 ValueError: If `segment_ids.shape` is not a prefix of `data.shape`. 156""" 157 158 159def _ragged_segment_aggregate(unsorted_segment_op, 160 data, 161 segment_ids, 162 num_segments, 163 separator=None, 164 name=None): 165 """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`. 166 167 Returns a RaggedTensor `output` with `num_segments` rows, where the row 168 `output[i]` is formed by combining all rows of `data` whose corresponding 169 `segment_id` is `i`. The values in each row are combined using 170 `unsorted_segment_op`. 171 172 The length of the row `output[i]` will be the maximum of the lengths of 173 all rows of `data` whose corresponding `segment_id` is `i`. If no `data` 174 rows correspond to a given segment ID, then the output row for that segment 175 ID will be empty. 176 177 Args: 178 unsorted_segment_op: The tensorflow `op` that should be used to combine 179 values in each row. Must have the same signature and basic behavior as 180 `unsorted_segment_sum`, `unsorted_segment_max`, etc. 181 data: A `RaggedTensor` containing the values to be combined. 182 segment_ids: A `Tensor` or `RaggedTensor`. Must have type `int64` or 183 `int32`. `segment_ids.shape` must be a prefix of `data.shape`. 184 `segment_ids` is not required to be sorted. 185 num_segments: An `int32` or `int64` scalar. 186 separator: An optional string. Defaults to None. The separator to use when 187 joining. Only used for string types. 188 name: A name prefix for the returned tensor (optional). 189 190 Returns: 191 A `RaggedTensor` containing the aggregated values. The returned tensor 192 has the same dtype as `data`, and its shape is 193 `[num_segments] + data.shape[segment_ids.rank:]`. 194 Raises: 195 ValueError: If segment_ids.shape is not a prefix of data.shape. 196 """ 197 if not (ragged_tensor.is_ragged(data) or 198 ragged_tensor.is_ragged(segment_ids)): 199 if separator is not None: 200 # It uses unsorted_segment_join. 201 return unsorted_segment_op(data, segment_ids, num_segments, separator, 202 name) 203 else: 204 return unsorted_segment_op(data, segment_ids, num_segments, name) 205 206 with ops.name_scope(name, 'RaggedSegment', 207 [data, segment_ids, num_segments]) as name: 208 data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data') 209 segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor( 210 segment_ids, name='segment_ids') 211 data, segment_ids = ragged_tensor.match_row_splits_dtypes(data, segment_ids) 212 if segment_ids.dtype not in (dtypes.int32, dtypes.int64): 213 raise ValueError('segment_ids must have dtype int32 or int64.') 214 215 if ragged_tensor.is_ragged(segment_ids): 216 if not ragged_tensor.is_ragged(data): 217 raise ValueError('segment_ids.shape must be a prefix of data.shape, ' 218 'but segment_ids is ragged and data is not.') 219 check_splits = check_ops.assert_equal( 220 segment_ids.row_splits, 221 data.row_splits, 222 message='segment_ids.shape must be a prefix of data.shape') 223 with ops.control_dependencies([check_splits]): 224 return _ragged_segment_aggregate(unsorted_segment_op, data.values, 225 segment_ids.values, num_segments, 226 separator) 227 228 # Find the length of each row in data. (shape=[data_nrows]) 229 data_row_lengths = data.row_splits[1:] - data.row_splits[:-1] 230 231 # Find the length that each output row will have. The length of the row 232 # corresponding to segment `id` is `max(data_row_lengths[i])` where 233 # `segment_ids[i]=id`. (shape=[output_nrows]) 234 output_row_lengths = math_ops.maximum( 235 math_ops.unsorted_segment_max(data_row_lengths, segment_ids, 236 num_segments), 0) 237 238 # Build the splits tensor for the output RaggedTensor. 239 output_splits = array_ops.concat([ 240 array_ops.zeros([1], output_row_lengths.dtype), 241 math_ops.cumsum(output_row_lengths) 242 ], 243 axis=0) 244 245 # For each row in `data`, find the start & limit position where that row's 246 # values will be aggregated in output.values. 247 data_row_to_out_row_start = array_ops.gather(output_splits, segment_ids) 248 data_row_to_out_row_limit = data_row_to_out_row_start + data_row_lengths 249 250 # For each value in `data.values`, find the position where it will 251 # aggregated in `output.values`. 252 # Get the target output values index for each data values index. 253 data_val_to_out_val_index = range(data_row_to_out_row_start, 254 data_row_to_out_row_limit).values 255 256 # Recursively aggregate the values. 257 output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values, 258 data_val_to_out_val_index, 259 output_splits[-1], separator) 260 return ragged_tensor.RaggedTensor.from_row_splits( 261 output_values, output_splits, validate=False) 262 263 264def segment_sum(data, segment_ids, num_segments, name=None): 265 # For docs, see: _RAGGED_SEGMENT_DOCSTRING 266 return _ragged_segment_aggregate( 267 math_ops.unsorted_segment_sum, 268 data=data, 269 segment_ids=segment_ids, 270 num_segments=num_segments, 271 name=(name or 'RaggedSegmentSum')) 272 273 274def segment_prod(data, segment_ids, num_segments, name=None): 275 # For docs, see: _RAGGED_SEGMENT_DOCSTRING 276 return _ragged_segment_aggregate( 277 math_ops.unsorted_segment_prod, 278 data=data, 279 segment_ids=segment_ids, 280 num_segments=num_segments, 281 name=(name or 'RaggedSegmentProd')) 282 283 284def segment_min(data, segment_ids, num_segments, name=None): 285 # For docs, see: _RAGGED_SEGMENT_DOCSTRING 286 return _ragged_segment_aggregate( 287 math_ops.unsorted_segment_min, 288 data=data, 289 segment_ids=segment_ids, 290 num_segments=num_segments, 291 name=(name or 'RaggedSegmentMin')) 292 293 294def segment_max(data, segment_ids, num_segments, name=None): 295 # For docs, see: _RAGGED_SEGMENT_DOCSTRING 296 return _ragged_segment_aggregate( 297 math_ops.unsorted_segment_max, 298 data=data, 299 segment_ids=segment_ids, 300 num_segments=num_segments, 301 name=(name or 'RaggedSegmentMax')) 302 303 304def segment_mean(data, segment_ids, num_segments, name=None): 305 """For docs, see: _RAGGED_SEGMENT_DOCSTRING.""" 306 with ops.name_scope(name, 'RaggedSegmentMean', 307 [data, segment_ids, num_segments]): 308 total = segment_sum(data, segment_ids, num_segments) 309 ones = ragged_tensor.RaggedTensor.from_nested_row_splits( 310 array_ops.ones_like(data.flat_values), 311 data.nested_row_splits, 312 validate=False) 313 count = segment_sum(ones, segment_ids, num_segments) 314 if ragged_tensor.is_ragged(total): 315 return total.with_flat_values(total.flat_values / count.flat_values) 316 else: 317 return total / count 318 319 320def segment_sqrt_n(data, segment_ids, num_segments, name=None): 321 """For docs, see: _RAGGED_SEGMENT_DOCSTRING.""" 322 with ops.name_scope(name, 'RaggedSegmentSqrtN', 323 [data, segment_ids, num_segments]): 324 total = segment_sum(data, segment_ids, num_segments) 325 ones = ragged_tensor.RaggedTensor.from_nested_row_splits( 326 array_ops.ones_like(data.flat_values), 327 data.nested_row_splits, 328 validate=False) 329 count = segment_sum(ones, segment_ids, num_segments) 330 if ragged_tensor.is_ragged(total): 331 return total.with_flat_values(total.flat_values / 332 math_ops.sqrt(count.flat_values)) 333 else: 334 return total / math_ops.sqrt(count) 335 336 337def _set_ragged_segment_docstring(func, combination, combined): 338 func.__doc__ = _RAGGED_SEGMENT_DOCSTRING % dict( 339 combination=combination, combined=combined) 340 341 342_set_ragged_segment_docstring(segment_sum, 'sum', 'summed') 343_set_ragged_segment_docstring(segment_prod, 'product', 'multiplied') 344_set_ragged_segment_docstring(segment_min, 'minimum', 'minimized') 345_set_ragged_segment_docstring(segment_max, 'maximum', 'maximized') 346_set_ragged_segment_docstring(segment_mean, 'mean', 'averaged') 347_set_ragged_segment_docstring(segment_sqrt_n, 'sum divided by sqrt(N)', 348 'summed') 349 350#=============================================================================== 351# ragged_reduce_<AGGREGATE> 352#=============================================================================== 353 354# Docstring template used for ragged_reduce_<AGGREGATE> ops. 355_RAGGED_REDUCE_DOCSTRING = """\ 356Computes the %(combination)s of elements across dimensions of a `RaggedTensor`. 357 358 Reduces `input_tensor` along the dimensions given in `axis` by taking the 359 %(combination)s of values. If a reduced dimension has no elements for 360 some index, then the value for that index will be %(default)s. 361 362 The rank of the tensor is reduced by `1` for each entry in `axis`. If 363 `axis` is not specified, then all dimensions are reduced, and a scalar 364 value is returned. 365 Args: 366 input_tensor: A `RaggedTensor` containing the values to be %(combined)s. 367 axis: The dimensions to reduce. May be `None` (to reduce all axes), an 368 `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce 369 a given set of axes), or a `Tensor` with a constant value. Must be in 370 the range `[0, input_tensor.rank]`. 371 name: A name prefix for the returned tensor (optional). 372 Returns: 373 A `RaggedTensor` containing the %(combined)s values. The returned tensor 374 has the same dtype as `data`, and its shape is given by removing the 375 dimensions specified in `axis` from `input_tensor.shape`. The `ragged_rank` 376 of the returned tensor is given by substracting any ragged dimensions 377 specified in `axis` from `input_tensor.ragged_rank`. 378 Raises: 379 ValueError: If `axis` contains a `Tensor` whose value is not constant. 380 ####Example: 381 %(example)s 382""" 383_RAGGED_REDUCE_SUM_EXAMPLE = """ 384 >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]]) 385 >>> tf.reduce_sum(rt, axis=0).numpy() # = [3+1+9+2, 1+5+6, 4] 386 array([15, 12, 4], dtype=int32) 387 >>> tf.reduce_sum(rt, axis=1).numpy() # = [3+1+4, 1+5, 9, 2+6] 388 array([8, 6, 9, 8], dtype=int32) 389""" 390_RAGGED_REDUCE_PROD_EXAMPLE = """ 391 >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]]) 392 >>> tf.reduce_prod(rt, axis=0).numpy() # = [3*1*9*2, 1*5*6, 4] 393 array([54, 30, 4], dtype=int32) 394 >>> tf.reduce_prod(rt, axis=1).numpy() # = [3*1*4, 1*5, 9, 2*6] 395 array([12, 5, 9, 12], dtype=int32) 396""" 397_RAGGED_REDUCE_MIN_EXAMPLE = """ 398 >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]]) 399 >>> tf.reduce_min(rt, axis=0).numpy() 400 array([1, 1, 4], dtype=int32) 401 >>> tf.reduce_min(rt, axis=1).numpy() 402 array([1, 1, 9, 2], dtype=int32) 403""" 404_RAGGED_REDUCE_MAX_EXAMPLE = """ 405 >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]]) 406 >>> tf.reduce_max(rt, axis=0).numpy() 407 array([9, 6, 4], dtype=int32) 408 >>> tf.reduce_max(rt, axis=1).numpy() 409 array([4, 5, 9, 6], dtype=int32) 410""" 411_RAGGED_REDUCE_MEAN_EXAMPLE = """ 412 >>> rt = tf.ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]]) 413 >>> tf.reduce_mean(rt, axis=0).numpy() 414 array([3.75, 4. , 4. ]) 415 >>> tf.reduce_mean(rt, axis=1).numpy() 416 array([2.66666667, 3. , 9. , 4. ]) 417""" 418_RAGGED_REDUCE_ALL_EXAMPLE = """ 419 >>> rt = tf.ragged.constant([[True, True], [True, True, False, True], [False, True]]) 420 >>> tf.reduce_all(rt, axis=0).numpy() 421 array([False, True, False, True]) 422 >>> tf.reduce_all(rt, axis=1).numpy() 423 array([ True, False, False]) 424""" 425_RAGGED_REDUCE_ANY_EXAMPLE = """ 426 >>> rt = tf.ragged.constant([[True, True], [True, True, False, True], [False, True]]) 427 >>> tf.reduce_any(rt, axis=0).numpy() 428 array([ True, True, False, True]) 429 >>> tf.reduce_any(rt, axis=1).numpy() 430 array([ True, True, True]) 431""" 432 433 434def ragged_reduce_aggregate(reduce_op, 435 unsorted_segment_op, 436 rt_input, 437 axis, 438 keepdims, 439 separator=None, 440 name=None): 441 """Aggregates across axes of a RaggedTensor using the given `Tensor` ops. 442 443 Reduces `rt_input` along the dimensions given in `axis`. The rank of the 444 tensor is reduced by 1 for each entry in `axis`. If `axis` is not specified, 445 then all dimensions are reduced, and a scalar value is returned. 446 447 This op assumes that `reduce_op` and `unsorted_segment_op` are associative; 448 if not, then reducing multiple axes will return incorrect results. (In 449 particular, reducing multiple axes is currently implemented by reducing the 450 axes one at a time.) 451 452 Args: 453 reduce_op: The tensorflow `op` that should be used to reduce values in 454 uniform dimensions. Must have the same signature and basic behavior as 455 `reduce_sum`, `reduce_max`, etc. 456 unsorted_segment_op: The tensorflow `op` that should be used to combine 457 values in ragged dimensions. Must have the same signature and basic 458 behavior as `unsorted_segment_sum`, `unsorted_segment_max`, etc. 459 rt_input: A `Tensor` or `RaggedTensor` containing the values to be reduced. 460 axis: The axis or axes to reduce. May be `None` (to reduce all axes), an 461 `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce a 462 given set of axes), or a `Tensor` with a constant value. Must be in the 463 range `[0, rt_input.rank)`. 464 keepdims: If true, retains reduced dimensions with length 1. 465 separator: An optional string. Defaults to None. The separator to use when 466 joining. The separator must not be set for non-string data types. (i.e. if 467 separator is not None then it uses string ops) 468 name: A name prefix for the returned tensor (optional). 469 470 Returns: 471 A `RaggedTensor` containing the reduced values. The returned tensor 472 has the same dtype as `data`, and its shape is given by removing the 473 dimensions specified in `axis` from `rt_input.shape`. The `ragged_rank` 474 of the returned tensor is given by substracting any ragged dimensions 475 specified in `axis` from `rt_input.ragged_rank`. 476 Raises: 477 ValueError: If `axis` contains a `Tensor` whose value is not constant. 478 """ 479 if not ragged_tensor.is_ragged(rt_input): 480 if separator is None: 481 return reduce_op(rt_input, axis, keepdims=keepdims, name=name) 482 else: 483 # When separator is not None, We infer that dtype is string and 484 # reduce_join will be called. 485 return reduce_op( 486 rt_input, axis, keepdims=keepdims, name=name, separator=separator) 487 488 if isinstance(axis, ops.Tensor): 489 axis = tensor_util.constant_value(axis) 490 if axis is None: 491 raise ValueError('axis must be known at graph construction time.') 492 if isinstance(axis, np.ndarray): 493 axis = axis.tolist() 494 495 # When reducing all axes, just ignore splits & reduce the inner values. 496 if axis is None: 497 result = reduce_op(rt_input.flat_values, None, keepdims=keepdims, name=name) 498 if keepdims: 499 # Expand the result to the input number of dimensions. 500 for _ in rt_input.shape[1:]: 501 result = array_ops.expand_dims(result, axis=0) 502 return result 503 504 with ops.name_scope(name, 'RaggedReduce', [rt_input, axis]): 505 if isinstance(axis, (tuple, list)): 506 if not axis: 507 return rt_input 508 elif len(axis) == 1: 509 axis = axis[0] 510 else: 511 # When reducing multiple axes, as we reduce one at a time (see below), 512 # the negative axis has to be converted to positive at the first run 513 # as the sort with negative axis will have different orders. 514 # See GitHub issue 27497. 515 axis = [ 516 array_ops.get_positive_axis(a, rt_input.shape.ndims, 'axis[%s]' % i, 517 'rank(input_tensor)') 518 for i, a in enumerate(axis) 519 ] 520 # When reducing multiple axes, just reduce one at a time. This is less 521 # efficient, and only works for associative ops. (In particular, it 522 # does not work for reduce_mean.) However, reducing multiple axes at 523 # once will probably require a nontrivial c++ op. 524 axis = sorted(axis) 525 inner_reduced = ragged_reduce_aggregate(reduce_op, unsorted_segment_op, 526 rt_input, axis[-1], keepdims, 527 separator) 528 return ragged_reduce_aggregate(reduce_op, unsorted_segment_op, 529 inner_reduced, axis[:-1], keepdims, 530 separator) 531 532 rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor( 533 rt_input, name='rt_input') 534 535 axis = array_ops.get_positive_axis( 536 axis, rt_input.shape.ndims, ndims_name='rank(input_tensor)') 537 538 if axis == 0: 539 # out[i_1, i_2, ..., i_N] = sum_{j} rt_input[j, i_1, i_2, ..., i_N] 540 row_lengths = rt_input.row_splits[1:] - rt_input.row_splits[:-1] 541 num_segments = math_ops.maximum(math_ops.reduce_max(row_lengths), 0) 542 segment_ids = range(row_lengths).values 543 result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values, 544 segment_ids, num_segments, separator) 545 if keepdims: 546 result = array_ops.expand_dims(result, axis=0) 547 return result 548 elif axis == 1: 549 # out[i_0, i_1, i_2, ..., i_N] = sum_{j} rt_input[i_0, j, i_2, ..., i_N] 550 num_segments = array_ops.shape(rt_input.row_splits)[0] - 1 551 segment_ids = segment_id_ops.row_splits_to_segment_ids( 552 rt_input.row_splits) 553 result = _ragged_segment_aggregate(unsorted_segment_op, rt_input.values, 554 segment_ids, num_segments, separator) 555 if keepdims: 556 result = array_ops.expand_dims(result, axis=1) 557 return result 558 else: 559 # out[i_0, ..., i_[axis-1], i_axis+1], ..., i_N] = 560 # sum_{j} rt_input [i_0, ..., i_[axis-1], j, i_axis+1], ..., i_N] 561 return rt_input.with_values( 562 ragged_reduce_aggregate(reduce_op, unsorted_segment_op, 563 rt_input.values, axis - 1, keepdims, 564 separator)) 565 566 567def reduce_sum(input_tensor, axis=None, keepdims=None, name=None): 568 """For docs, see: _RAGGED_REDUCE_DOCSTRING.""" 569 570 return ragged_reduce_aggregate( 571 reduce_op=math_ops.reduce_sum, 572 unsorted_segment_op=math_ops.unsorted_segment_sum, 573 rt_input=input_tensor, 574 axis=axis, 575 keepdims=keepdims, 576 name=(name or 'RaggedReduceSum')) 577 578 579def reduce_prod(input_tensor, axis=None, keepdims=None, name=None): 580 """For docs, see: _RAGGED_REDUCE_DOCSTRING.""" 581 return ragged_reduce_aggregate( 582 reduce_op=math_ops.reduce_prod, 583 unsorted_segment_op=math_ops.unsorted_segment_prod, 584 rt_input=input_tensor, 585 axis=axis, 586 keepdims=keepdims, 587 name=(name or 'RaggedReduceProd')) 588 589 590def reduce_min(input_tensor, axis=None, keepdims=None, name=None): 591 """For docs, see: _RAGGED_REDUCE_DOCSTRING.""" 592 return ragged_reduce_aggregate( 593 reduce_op=math_ops.reduce_min, 594 unsorted_segment_op=math_ops.unsorted_segment_min, 595 rt_input=input_tensor, 596 axis=axis, 597 keepdims=keepdims, 598 name=(name or 'RaggedReduceMin')) 599 600 601def reduce_max(input_tensor, axis=None, keepdims=None, name=None): 602 """For docs, see: _RAGGED_REDUCE_DOCSTRING.""" 603 return ragged_reduce_aggregate( 604 reduce_op=math_ops.reduce_max, 605 unsorted_segment_op=math_ops.unsorted_segment_max, 606 rt_input=input_tensor, 607 axis=axis, 608 keepdims=keepdims, 609 name=(name or 'RaggedReduceMax')) 610 611 612def reduce_mean(input_tensor, axis=None, keepdims=None, name=None): 613 """For docs, see: _RAGGED_REDUCE_DOCSTRING.""" 614 with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]): 615 total = reduce_sum(input_tensor, axis, keepdims) 616 if ragged_tensor.is_ragged(input_tensor): 617 ones = ragged_tensor.RaggedTensor.from_nested_row_splits( 618 array_ops.ones_like(input_tensor.flat_values), 619 input_tensor.nested_row_splits, 620 validate=False) 621 else: 622 ones = array_ops.ones_like(input_tensor) 623 count = reduce_sum(ones, axis, keepdims) 624 if ragged_tensor.is_ragged(total): 625 return ragged_tensor.RaggedTensor.from_nested_row_splits( 626 total.flat_values / count.flat_values, 627 total.nested_row_splits, 628 validate=False) 629 else: 630 return total / count 631 632 633def _cast(input_tensor, dtype): 634 return ragged_functional_ops.map_flat_values(math_ops.cast, input_tensor, 635 dtype) 636 637 638def reduce_all(input_tensor, axis=None, keepdims=None, name=None): 639 """For docs, see: _RAGGED_REDUCE_DOCSTRING.""" 640 with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]): 641 return _cast( 642 reduce_prod(_cast(input_tensor, dtypes.int32), axis, keepdims), 643 dtypes.bool) 644 645 646def reduce_any(input_tensor, axis=None, keepdims=None, name=None): 647 """For docs, see: _RAGGED_REDUCE_DOCSTRING.""" 648 with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]): 649 return _cast( 650 reduce_sum(_cast(input_tensor, dtypes.int32), axis, keepdims), 651 dtypes.bool) 652 653 654def _set_ragged_reduce_docstring(func, combination, combined, default, example): 655 func.__doc__ = _RAGGED_REDUCE_DOCSTRING % dict( 656 combination=combination, 657 combined=combined, 658 default=default, 659 example=example) 660 661 662_set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0', 663 _RAGGED_REDUCE_SUM_EXAMPLE) 664_set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1', 665 _RAGGED_REDUCE_PROD_EXAMPLE) 666_set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized', 667 '`input_tensor.dtype.min`', 668 _RAGGED_REDUCE_MIN_EXAMPLE) 669_set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized', 670 '`input_tensor.dtype.max`', 671 _RAGGED_REDUCE_MAX_EXAMPLE) 672_set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN', 673 _RAGGED_REDUCE_MEAN_EXAMPLE) 674 675_set_ragged_reduce_docstring(reduce_all, 'logical and', 'and-ed', 'True', 676 _RAGGED_REDUCE_ALL_EXAMPLE) 677_set_ragged_reduce_docstring(reduce_any, 'logical or', 'or-ed', 'False', 678 _RAGGED_REDUCE_ANY_EXAMPLE) 679