# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Support for ragged tensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_ragged_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.util.tf_export import tf_export


#===============================================================================
# ragged.range
#===============================================================================
# pylint: disable=redefined-builtin
@tf_export('ragged.range')
def range(starts, limits=None, deltas=1, dtype=None, name=None):
  """Returns a `RaggedTensor` containing the specified sequences of numbers.

  Each row of the returned `RaggedTensor` contains a single sequence:

  ```python
  ragged.range(starts, limits, deltas)[i] ==
      tf.range(starts[i], limits[i], deltas[i])
  ```

  If `starts[i] >= limits[i]` and `deltas[i] > 0`, then `output[i]` will be an
  empty list.  Similarly, if `starts[i] <= limits[i]` and `deltas[i] < 0`, then
  `output[i]` will be an empty list.  This behavior is consistent with the
  Python `range` function, but differs from the `tf.range` op, which returns
  an error for these cases.

  Examples:

  ```python
  >>> ragged.range([3, 5, 2]).eval().tolist()
  [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]]
  >>> ragged.range([0, 5, 8], [3, 3, 12]).eval().tolist()
  [[0, 1, 2], [], [8, 9, 10, 11]]
  >>> ragged.range([0, 5, 8], [3, 3, 12], 2).eval().tolist()
  [[0, 2], [], [8, 10]]
  ```

  The input tensors `starts`, `limits`, and `deltas` may be scalars or vectors.
  The vector inputs must all have the same size.  Scalar inputs are broadcast
  to match the size of the vector inputs.

  Args:
    starts: Vector or scalar `Tensor`.  Specifies the first entry for each
      range if `limits` is not `None`; otherwise, specifies the range limits,
      and the first entries default to `0`.
    limits: Vector or scalar `Tensor`.  Specifies the exclusive upper limits
      for each range.
    deltas: Vector or scalar `Tensor`.  Specifies the increment for each range.
      Defaults to `1`.
    dtype: The type of the elements of the resulting tensor.  If not specified,
      then a value is chosen based on the other args.
    name: A name for the operation.

  Returns:
    A `RaggedTensor` of type `dtype` with `ragged_rank=1`.
  """
  if limits is None:
    starts, limits = 0, starts

  with ops.name_scope(name, 'RaggedRange', [starts, limits, deltas]) as name:
    starts = ops.convert_to_tensor(starts, dtype=dtype, name='starts')
    limits = ops.convert_to_tensor(limits, dtype=dtype, name='limits')
    deltas = ops.convert_to_tensor(deltas, dtype=dtype, name='deltas')

    # infer dtype if not explicitly provided
    if dtype is None:
      starts, limits, deltas = _infer_matching_dtype(
          [starts, limits, deltas],
          [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])

    result = gen_ragged_math_ops.ragged_range(starts, limits, deltas,
                                              name=name)
    return ragged_tensor.RaggedTensor.from_row_splits(result.rt_dense_values,
                                                      result.rt_nested_splits)


def _infer_matching_dtype(tensors, dtype_hierarchy):
  """Infers a matching dtype for tensors, and casts them to that dtype."""
  assert all(t.dtype in dtype_hierarchy for t in tensors)
  inferred_dtype = max([t.dtype for t in tensors], key=dtype_hierarchy.index)
  return [math_ops.cast(t, inferred_dtype) for t in tensors]
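

# Illustrative sketch (editor's comment, not part of the original module):
# `_infer_matching_dtype` promotes all inputs to whichever dtype appears
# latest in the hierarchy list.  For example, mixing int32 and float32 inputs
# yields float32 for both:
#
#   starts = ops.convert_to_tensor([0, 5], dtype=dtypes.int32)
#   deltas = ops.convert_to_tensor([0.5, 2.0], dtype=dtypes.float32)
#   starts, deltas = _infer_matching_dtype(
#       [starts, deltas],
#       [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])
#   # starts.dtype == deltas.dtype == dtypes.float32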


#===============================================================================
# ragged_segment_<AGGREGATE>
#===============================================================================

# Docstring template used for the ragged_segment_<AGGREGATE> ops.
_RAGGED_SEGMENT_DOCSTRING = """\
Computes the %(combination)s along segments of a RaggedTensor.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by taking the %(combination)s of all rows of `data`
  whose corresponding `segment_id` is `i`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    data: A `RaggedTensor` containing the values to combine.
    segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
      `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
      Must be greater than or equal to zero, and less than `num_segments`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar specifying the number of
      distinct segment ids.
    name: A name prefix for the returned tensor (optional).
  Returns:
    A `RaggedTensor` containing the %(combined)s values.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
"""


def _ragged_segment_aggregate(unsorted_segment_op,
                              data,
                              segment_ids,
                              num_segments,
                              name=None):
  """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.

  Returns a RaggedTensor `output` with `num_segments` rows, where the row
  `output[i]` is formed by combining all rows of `data` whose corresponding
  `segment_id` is `i`.  The values in each row are combined using
  `unsorted_segment_op`.

  The length of the row `output[i]` will be the maximum of the lengths of
  all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
  rows correspond to a given segment ID, then the output row for that segment
  ID will be empty.

  Args:
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in each row.  Must have the same signature and basic behavior as
      `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    data: A `RaggedTensor` containing the values to be combined.
    segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
      `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
      `segment_ids` is not required to be sorted.
    num_segments: An `int32` or `int64` scalar.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the aggregated values.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_segments] + data.shape[segment_ids.rank:]`.
  Raises:
    ValueError: If `segment_ids.shape` is not a prefix of `data.shape`.
  """
  if not (ragged_tensor.is_ragged(data) or
          ragged_tensor.is_ragged(segment_ids)):
    return unsorted_segment_op(data, segment_ids, num_segments, name)

  with ops.name_scope(name, 'RaggedSegment',
                      [data, segment_ids, num_segments]) as name:
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        segment_ids, name='segment_ids')

    if ragged_tensor.is_ragged(segment_ids):
      if not ragged_tensor.is_ragged(data):
        raise ValueError('segment_ids.shape must be a prefix of data.shape, '
                         'but segment_ids is ragged and data is not.')
      check_splits = check_ops.assert_equal(
          segment_ids.row_splits,
          data.row_splits,
          message='segment_ids.shape must be a prefix of data.shape')
      with ops.control_dependencies([check_splits]):
        return _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                         segment_ids.values, num_segments,
                                         name)

    segment_ids = math_ops.cast(segment_ids, dtypes.int64)

    # Find the length of each row in data.  (dtype=int64, shape=[data_nrows])
    data_row_lengths = data.row_splits[1:] - data.row_splits[:-1]

    # Find the length that each output row will have.  The length of the row
    # corresponding to segment `id` is `max(data_row_lengths[i])` where
    # `segment_ids[i]=id`.  (dtype=int64, shape=[output_nrows])
    output_row_lengths = math_ops.maximum(
        math_ops.unsorted_segment_max(data_row_lengths, segment_ids,
                                      num_segments), 0)
    assert output_row_lengths.dtype == dtypes.int64

    # Build the splits tensor for the output RaggedTensor.
    output_splits = array_ops.concat([
        array_ops.zeros([1], dtypes.int64),
        math_ops.cumsum(output_row_lengths)
    ], axis=0)

    # For each row in `data`, find the start & limit position where that row's
    # values will be aggregated in output.values.
    data_row_to_out_row_start = array_ops.gather(output_splits, segment_ids)
    data_row_to_out_row_limit = data_row_to_out_row_start + data_row_lengths

    # For each value in `data.values`, find the position where it will be
    # aggregated in `output.values`.
    # Get the target output values index for each data values index.
    data_val_to_out_val_index = range(data_row_to_out_row_start,
                                      data_row_to_out_row_limit).values

    # Recursively aggregate the values.
    output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                              data_val_to_out_val_index,
                                              output_splits[-1])
    return ragged_tensor.RaggedTensor.from_row_splits(output_values,
                                                      output_splits)
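

# Worked example (editor's comment, not part of the original module): suppose
# `data` has rows [[1, 2], [3], [4, 5, 6]], `segment_ids` is [0, 1, 0],
# `num_segments` is 2, and the combiner is `unsorted_segment_sum`.  Then the
# intermediates computed above are:
#
#   data_row_lengths          = [2, 1, 3]
#   output_row_lengths        = [3, 1]             # max row length per segment
#   output_splits             = [0, 3, 4]
#   data_row_to_out_row_start = [0, 3, 0]
#   data_row_to_out_row_limit = [2, 4, 3]
#   data_val_to_out_val_index = [0, 1, 3, 0, 1, 2]
#
# so the flat values [1, 2, 3, 4, 5, 6] are summed into output positions
# [0, 1, 3, 0, 1, 2], giving output values [5, 7, 6, 3] and the final result
# [[5, 7, 6], [3]] (segment 0 adds rows [1, 2] and [4, 5, 6] elementwise).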


def segment_sum(data, segment_ids, num_segments, name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(math_ops.unsorted_segment_sum, data,
                                   segment_ids, num_segments,
                                   name or 'RaggedSegmentSum')


def segment_prod(data, segment_ids, num_segments, name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(math_ops.unsorted_segment_prod, data,
                                   segment_ids, num_segments,
                                   name or 'RaggedSegmentProd')


def segment_min(data, segment_ids, num_segments, name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(math_ops.unsorted_segment_min, data,
                                   segment_ids, num_segments,
                                   name or 'RaggedSegmentMin')


def segment_max(data, segment_ids, num_segments, name=None):
  # For docs, see: _RAGGED_SEGMENT_DOCSTRING
  return _ragged_segment_aggregate(math_ops.unsorted_segment_max, data,
                                   segment_ids, num_segments,
                                   name or 'RaggedSegmentMax')


def segment_mean(data, segment_ids, num_segments, name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentMean',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.ones_like(data.flat_values), data.nested_row_splits)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(total.flat_values / count.flat_values)
    else:
      return total / count


def segment_sqrt_n(data, segment_ids, num_segments, name=None):
  """For docs, see: _RAGGED_SEGMENT_DOCSTRING."""
  with ops.name_scope(name, 'RaggedSegmentSqrtN',
                      [data, segment_ids, num_segments]):
    total = segment_sum(data, segment_ids, num_segments)
    ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
        array_ops.ones_like(data.flat_values), data.nested_row_splits)
    count = segment_sum(ones, segment_ids, num_segments)
    if ragged_tensor.is_ragged(total):
      return total.with_flat_values(
          total.flat_values / math_ops.sqrt(count.flat_values))
    else:
      return total / math_ops.sqrt(count)


def _set_ragged_segment_docstring(func, combination, combined):
  func.__doc__ = _RAGGED_SEGMENT_DOCSTRING % dict(
      combination=combination, combined=combined)


_set_ragged_segment_docstring(segment_sum, 'sum', 'summed')
_set_ragged_segment_docstring(segment_prod, 'product', 'multiplied')
_set_ragged_segment_docstring(segment_min, 'minimum', 'minimized')
_set_ragged_segment_docstring(segment_max, 'maximum', 'maximized')
_set_ragged_segment_docstring(segment_mean, 'mean', 'averaged')
_set_ragged_segment_docstring(segment_sqrt_n, 'sum divided by sqrt(N)',
                              'summed')
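

# Usage sketch for `segment_mean` (editor's comment, not part of the original
# module; `ragged.constant` is used as in the docstring examples above):
#
#   data = ragged.constant([[1., 3.], [2.], [5., 7., 9.]])
#   segment_mean(data, segment_ids=[0, 0, 1], num_segments=2)
#   # -> [[1.5, 3.], [5., 7., 9.]]
#
# The per-position count comes from segment-summing a ragged tensor of ones
# with the same nested splits as `data`, so a position that appears in only
# one row of a segment is divided by 1 rather than by the number of rows.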


#===============================================================================
# ragged_reduce_<AGGREGATE>
#===============================================================================

# Docstring template used for ragged_reduce_<AGGREGATE> ops.
_RAGGED_REDUCE_DOCSTRING = """\
Computes the %(combination)s of elements across dimensions of a `RaggedTensor`.

  Reduces `input_tensor` along the dimensions given in `axis` by taking the
  %(combination)s of values.  If a reduced dimension has no elements for
  some index, then the value for that index will be %(default)s.

  The rank of the tensor is reduced by `1` for each entry in `axis`.  If
  `axis` is not specified, then all dimensions are reduced, and a scalar
  value is returned.
  Args:
    input_tensor: A `RaggedTensor` containing the values to be %(combined)s.
    axis: The dimensions to reduce.  May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
      a given set of axes), or a `Tensor` with a constant value.  Must be in
      the range `[0, input_tensor.rank]`.
    name: A name prefix for the returned tensor (optional).
  Returns:
    A `RaggedTensor` containing the %(combined)s values.  The returned tensor
    has the same dtype as `data`, and its shape is given by removing the
    dimensions specified in `axis` from `input_tensor.shape`.  The
    `ragged_rank` of the returned tensor is given by subtracting any ragged
    dimensions specified in `axis` from `input_tensor.ragged_rank`.
  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant.
  ####Example:
    ```python%(example)s    ```
"""
_RAGGED_REDUCE_SUM_EXAMPLE = """
    >>> rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> ragged.reduce_sum(rt, axis=0).eval().tolist()
    [15, 12, 4]  # = [3+1+9+2, 1+5+6, 4]
    >>> ragged.reduce_sum(rt, axis=1).eval().tolist()
    [8, 6, 9, 8]  # = [3+1+4, 1+5, 9, 2+6]
"""
_RAGGED_REDUCE_PROD_EXAMPLE = """
    >>> rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> ragged.reduce_prod(rt, axis=0).eval().tolist()
    [54, 30, 4]  # = [3*1*9*2, 1*5*6, 4]
    >>> ragged.reduce_prod(rt, axis=1).eval().tolist()
    [12, 5, 9, 12]  # = [3*1*4, 1*5, 9, 2*6]
"""
_RAGGED_REDUCE_MIN_EXAMPLE = """
    >>> rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> ragged.reduce_min(rt, axis=0).eval().tolist()
    [1, 1, 4]  # = [min(3, 1, 9, 2), min(1, 5, 6), 4]
    >>> ragged.reduce_min(rt, axis=1).eval().tolist()
    [1, 1, 9, 2]  # = [min(3, 1, 4), min(1, 5), 9, min(2, 6)]
"""
_RAGGED_REDUCE_MAX_EXAMPLE = """
    >>> rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> ragged.reduce_max(rt, axis=0).eval().tolist()
    [9, 6, 4]  # = [max(3, 1, 9, 2), max(1, 5, 6), 4]
    >>> ragged.reduce_max(rt, axis=1).eval().tolist()
    [4, 5, 9, 6]  # = [max(3, 1, 4), max(1, 5), 9, max(2, 6)]
"""
_RAGGED_REDUCE_MEAN_EXAMPLE = """
    >>> rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
    >>> ragged.reduce_mean(rt, axis=0).eval().tolist()
    [3.75, 4, 4]  # = [mean(3, 1, 9, 2), mean(1, 5, 6), 4]
    >>> ragged.reduce_mean(rt, axis=1).eval().tolist()
    [2.66666, 3, 9, 4]  # = [mean(3, 1, 4), mean(1, 5), 9, mean(2, 6)]
"""
_RAGGED_REDUCE_ALL_EXAMPLE = """
    >>> rt = ragged.constant([[True, True], [True, True, False, True], [False, True]])
    >>> ragged.reduce_all(rt, axis=0).eval().tolist()
    [False, True, False, True]
    >>> ragged.reduce_all(rt, axis=1).eval().tolist()
    [True, False, False]
"""
_RAGGED_REDUCE_ANY_EXAMPLE = """
    >>> rt = ragged.constant([[True, True], [True, True, False, True], [False, True]])
    >>> ragged.reduce_any(rt, axis=0).eval().tolist()
    [True, True, False, True]
    >>> ragged.reduce_any(rt, axis=1).eval().tolist()
    [True, True, True]
"""


def _ragged_reduce_aggregate(reduce_op,
                             unsorted_segment_op,
                             rt_input,
                             axis,
                             keepdims,
                             name=None):
  """Aggregates across axes of a RaggedTensor using the given `Tensor` ops.

  Reduces `rt_input` along the dimensions given in `axis`.  The rank of the
  tensor is reduced by 1 for each entry in `axis`.  If `axis` is not specified,
  then all dimensions are reduced, and a scalar value is returned.

  This op assumes that `reduce_op` and `unsorted_segment_op` are associative;
  if not, then reducing multiple axes will return incorrect results.  (In
  particular, reducing multiple axes is currently implemented by reducing the
  axes one at a time.)

  Args:
    reduce_op: The tensorflow `op` that should be used to reduce values in
      uniform dimensions.  Must have the same signature and basic behavior as
      `reduce_sum`, `reduce_max`, etc.
    unsorted_segment_op: The tensorflow `op` that should be used to combine
      values in ragged dimensions.  Must have the same signature and basic
      behavior as `unsorted_segment_sum`, `unsorted_segment_max`, etc.
    rt_input: A `Tensor` or `RaggedTensor` containing the values to be reduced.
    axis: The axis or axes to reduce.  May be `None` (to reduce all axes), an
      `int` (to reduce a single axis), a `list` or `tuple` of `int` (to reduce
      a given set of axes), or a `Tensor` with a constant value.  Must be in
      the range `[0, rt_input.rank)`.
    keepdims: If true, retains reduced dimensions with length 1.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the reduced values.  The returned tensor
    has the same dtype as `data`, and its shape is given by removing the
    dimensions specified in `axis` from `rt_input.shape`.  The `ragged_rank`
    of the returned tensor is given by subtracting any ragged dimensions
    specified in `axis` from `rt_input.ragged_rank`.
  Raises:
    ValueError: If `axis` contains a `Tensor` whose value is not constant.
  """
  if not ragged_tensor.is_ragged(rt_input):
    return reduce_op(rt_input, axis, name=name)

  if keepdims:
    raise ValueError('keepdims=True is not supported for RaggedTensors.')

  if isinstance(axis, ops.Tensor):
    axis = tensor_util.constant_value(axis)
    if axis is None:
      raise ValueError('axis must be known at graph construction time.')
    if isinstance(axis, np.ndarray):
      axis = axis.tolist()

  # When reducing all axes, just ignore splits & reduce the inner values.
  if axis is None:
    return reduce_op(rt_input.flat_values, None, name=name)

  with ops.name_scope(name, 'RaggedReduce', [rt_input, axis]):
    if isinstance(axis, (tuple, list)):
      if not axis:
        return rt_input
      elif len(axis) == 1:
        axis = axis[0]
      else:
        # When reducing multiple axes, just reduce one at a time.  This is
        # less efficient, and only works for associative ops.  (In particular,
        # it does not work for reduce_mean.)  However, reducing multiple axes
        # at once will probably require a nontrivial c++ op.
        axis = sorted(axis)
        inner_reduced = _ragged_reduce_aggregate(reduce_op,
                                                 unsorted_segment_op,
                                                 rt_input, axis[-1], keepdims)
        return _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                        inner_reduced, axis[:-1], keepdims)

    rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        rt_input, name='rt_input')

    axis = ragged_util.get_positive_axis(axis, rt_input.shape.ndims)

    if axis == 0:
      # out[i_1, i_2, ..., i_N] = sum_{j} rt_input[j, i_1, i_2, ..., i_N]
      row_lengths = rt_input.row_splits[1:] - rt_input.row_splits[:-1]
      num_segments = math_ops.maximum(math_ops.reduce_max(row_lengths), 0)
      segment_ids = range(row_lengths).values
      return _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                       segment_ids, num_segments)
    elif axis == 1:
      # out[i_0, i_1, i_2, ..., i_N] = sum_{j} rt_input[i_0, j, i_2, ..., i_N]
      num_segments = array_ops.shape(rt_input.row_splits)[0] - 1
      segment_ids = segment_id_ops.row_splits_to_segment_ids(
          rt_input.row_splits)
      return _ragged_segment_aggregate(unsorted_segment_op, rt_input.values,
                                       segment_ids, num_segments)
    else:
      # out[i_0, ..., i_[axis-1], i_[axis+1], ..., i_N] =
      #     sum_{j} rt_input[i_0, ..., i_[axis-1], j, i_[axis+1], ..., i_N]
      return rt_input.with_values(
          _ragged_reduce_aggregate(reduce_op, unsorted_segment_op,
                                   rt_input.values, axis - 1, keepdims))
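

# Illustrative comment (editor's addition, not part of the original module):
# reducing over multiple axes is performed one axis at a time, innermost
# first.  For example:
#
#   rt = ragged.constant([[3, 1, 4], [1, 5], [9], [2, 6]])
#   reduce_sum(rt, axis=[0, 1])
#   # pass 1 (axis=1): [8, 6, 9, 8]
#   # pass 2 (axis=0): 31
#
# This is why `reduce_op` and `unsorted_segment_op` must be associative; a
# non-associative combiner such as mean would give the wrong answer here.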


def reduce_sum(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return _ragged_reduce_aggregate(math_ops.reduce_sum,
                                  math_ops.unsorted_segment_sum, input_tensor,
                                  axis, keepdims, name or 'RaggedReduceSum')


def reduce_prod(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return _ragged_reduce_aggregate(math_ops.reduce_prod,
                                  math_ops.unsorted_segment_prod, input_tensor,
                                  axis, keepdims, name or 'RaggedReduceProd')


def reduce_min(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return _ragged_reduce_aggregate(math_ops.reduce_min,
                                  math_ops.unsorted_segment_min, input_tensor,
                                  axis, keepdims, name or 'RaggedReduceMin')


def reduce_max(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  return _ragged_reduce_aggregate(math_ops.reduce_max,
                                  math_ops.unsorted_segment_max, input_tensor,
                                  axis, keepdims, name or 'RaggedReduceMax')


def reduce_mean(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
    total = reduce_sum(input_tensor, axis, keepdims)
    if ragged_tensor.is_ragged(input_tensor):
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(input_tensor.flat_values),
          input_tensor.nested_row_splits)
    else:
      ones = array_ops.ones_like(input_tensor)
    count = reduce_sum(ones, axis, keepdims)
    if ragged_tensor.is_ragged(total):
      return ragged_tensor.RaggedTensor.from_nested_row_splits(
          total.flat_values / count.flat_values, total.nested_row_splits)
    else:
      return total / count
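

# Illustrative comment (editor's addition, not part of the original module):
# `reduce_mean` divides an elementwise sum by an elementwise count, where the
# count is obtained by reducing a ragged tensor of ones with the same nested
# splits as the input.  For example:
#
#   rt = ragged.constant([[3., 1., 4.], [1., 5.], [9.], [2., 6.]])
#   reduce_mean(rt, axis=0)
#   # total = [15., 12., 4.],  count = [4., 3., 1.]  ->  [3.75, 4., 4.]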


def _cast(input_tensor, dtype):
  return ragged_functional_ops.map_flat_values(math_ops.cast, input_tensor,
                                               dtype)


def reduce_all(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAll', [input_tensor, axis]):
    return _cast(
        reduce_prod(_cast(input_tensor, dtypes.int32), axis, keepdims),
        dtypes.bool)


def reduce_any(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceAny', [input_tensor, axis]):
    return _cast(
        reduce_sum(_cast(input_tensor, dtypes.int32), axis, keepdims),
        dtypes.bool)
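

# Illustrative comment (editor's addition, not part of the original module):
# `reduce_all` and `reduce_any` reuse the numeric reductions by round-tripping
# through int32: "all" becomes a product (any 0 makes the product 0) and "any"
# becomes a sum (any 1 makes the sum nonzero).  For example:
#
#   rt = ragged.constant([[True, False], [True]])
#   reduce_all(rt, axis=1)
#   # prod([1, 0]) = 0 -> False;  prod([1]) = 1 -> True
#   # -> [False, True]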


def _set_ragged_reduce_docstring(func, combination, combined, default,
                                 example):
  func.__doc__ = _RAGGED_REDUCE_DOCSTRING % dict(
      combination=combination,
      combined=combined,
      default=default,
      example=example)


_set_ragged_reduce_docstring(reduce_sum, 'sum', 'summed', '0',
                             _RAGGED_REDUCE_SUM_EXAMPLE)
_set_ragged_reduce_docstring(reduce_prod, 'product', 'multiplied', '1',
                             _RAGGED_REDUCE_PROD_EXAMPLE)
_set_ragged_reduce_docstring(reduce_min, 'minimum', 'minimized',
                             '`input_tensor.dtype.min`',
                             _RAGGED_REDUCE_MIN_EXAMPLE)
_set_ragged_reduce_docstring(reduce_max, 'maximum', 'maximized',
                             '`input_tensor.dtype.max`',
                             _RAGGED_REDUCE_MAX_EXAMPLE)
_set_ragged_reduce_docstring(reduce_mean, 'mean', 'averaged', 'NaN',
                             _RAGGED_REDUCE_MEAN_EXAMPLE)

_set_ragged_reduce_docstring(reduce_all, 'logical and', 'and-ed', 'True',
                             _RAGGED_REDUCE_ALL_EXAMPLE)
_set_ragged_reduce_docstring(reduce_any, 'logical or', 'or-ed', 'False',
                             _RAGGED_REDUCE_ANY_EXAMPLE)