# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Private convenience functions for RaggedTensors.

None of these methods are exposed in the main "ragged" package.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_ragged_math_ops
from tensorflow.python.ops import math_ops


def convert_to_int_tensor(tensor, name, dtype=dtypes.int32):
  """Converts the given value to an integer Tensor.

  Args:
    tensor: A value convertible to a Tensor.  Must already have an integer
      dtype (of any width); it is cast to `dtype`.
    name: A name for the tensor (also used in the error message).
    dtype: The target integer dtype (defaults to int32).

  Returns:
    A Tensor of dtype `dtype`.

  Raises:
    TypeError: If `tensor` does not have an integer dtype.
  """
  # `preferred_dtype` lets Python ints convert directly to `dtype`, while
  # still allowing tensors with other (integer) dtypes through for the cast.
  tensor = ops.convert_to_tensor(tensor, name=name, preferred_dtype=dtype)
  if tensor.dtype.is_integer:
    tensor = math_ops.cast(tensor, dtype)
  else:
    raise TypeError(
        "%s must be an integer tensor; dtype=%s" % (name, tensor.dtype))
  return tensor


def get_positive_axis(axis, ndims):
  """Validate an `axis` parameter, and normalize it to be positive.

  If `ndims` is known (i.e., not `None`), then check that `axis` is in the
  range `-ndims <= axis < ndims`, and return `axis` (if `axis >= 0`) or
  `axis + ndims` (otherwise).
  If `ndims` is not known, and `axis` is positive, then return it as-is.
  If `ndims` is not known, and `axis` is negative, then report an error.

  Args:
    axis: An integer constant
    ndims: An integer constant, or `None`

  Returns:
    The normalized `axis` value.

  Raises:
    ValueError: If `axis` is out-of-bounds, or if `axis` is negative and
      `ndims is None`.
  """
  if not isinstance(axis, int):
    raise TypeError("axis must be an int; got %s" % type(axis).__name__)
  if ndims is not None:
    if 0 <= axis < ndims:
      return axis
    elif -ndims <= axis < 0:
      # Negative axes count from the end, numpy-style.
      return axis + ndims
    else:
      raise ValueError(
          "axis=%s out of bounds: expected %s<=axis<%s" % (axis, -ndims, ndims))
  elif axis < 0:
    # Without a static rank there is no way to resolve a negative axis.
    raise ValueError("axis may only be negative if ndims is statically known.")
  return axis


def assert_splits_match(nested_splits_lists):
  """Checks that the given splits lists are identical.

  Performs static tests to ensure that the given splits lists are identical,
  and returns a list of control dependency op tensors that check that they are
  fully identical.

  Args:
    nested_splits_lists: A list of nested_splits_lists, where each split_list is
      a list of `splits` tensors from a `RaggedTensor`, ordered from outermost
      ragged dimension to innermost ragged dimension.

  Returns:
    A list of control dependency op tensors.
  Raises:
    ValueError: If the splits are not identical.
  """
  error_msg = "Inputs must have identical ragged splits"
  # Static check: all inputs must have the same ragged depth.
  for splits_list in nested_splits_lists:
    if len(splits_list) != len(nested_splits_lists[0]):
      raise ValueError(error_msg)
  # Dynamic check: compare every splits tensor against the first input's
  # corresponding splits tensor, at every ragged dimension.
  return [
      check_ops.assert_equal(s1, s2, message=error_msg)
      for splits_list in nested_splits_lists[1:]
      for (s1, s2) in zip(nested_splits_lists[0], splits_list)
  ]


# This op is intended to exactly match the semantics of numpy.repeat, with
# one exception: numpy.repeat has special (and somewhat non-intuitive) behavior
# when axis is not specified.
# Rather than implement that special behavior, we
# simply make `axis` be a required argument.
#
# External (OSS) `tf.repeat` feature request:
# https://github.com/tensorflow/tensorflow/issues/8246
def repeat(data, repeats, axis, name=None):
  """Repeats elements of `data`.

  Args:
    data: An `N`-dimensional tensor.
    repeats: A 1-D integer tensor specifying how many times each element in
      `axis` should be repeated.  `len(repeats)` must equal `data.shape[axis]`.
      Supports broadcasting from a scalar value.
    axis: `int`.  The axis along which to repeat values.  Must be less than
      `max(N, 1)`.
    name: A name for the operation.

  Returns:
    A tensor with `max(N, 1)` dimensions.  Has the same shape as `data`,
    except that dimension `axis` has size `sum(repeats)`.

  Raises:
    TypeError: If `axis` is not an int, or `repeats` is not an integer tensor.
    ValueError: If `axis` is out of bounds for `data`.

  #### Examples:
    ```python
    >>> repeat(['a', 'b', 'c'], repeats=[3, 0, 2], axis=0)
    ['a', 'a', 'a', 'c', 'c']
    >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=0)
    [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4]]
    >>> repeat([[1, 2], [3, 4]], repeats=[2, 3], axis=1)
    [[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]]
    ```
  """
  if not isinstance(axis, int):
    raise TypeError("axis must be an int; got %s" % type(axis).__name__)

  with ops.name_scope(name, "Repeat", [data, repeats]):
    data = ops.convert_to_tensor(data, name="data")
    repeats = convert_to_int_tensor(repeats, name="repeats")
    repeats.shape.with_rank_at_most(1)

    # If `data` is a scalar, then upgrade it to a vector.
    data = _with_nonzero_rank(data)
    data_shape = array_ops.shape(data)

    # If `axis` is negative, then convert it to a positive value.
    axis = get_positive_axis(axis, data.shape.ndims)

    # Check data Tensor shapes.
    if repeats.shape.ndims == 1:
      data.shape.dims[axis].assert_is_compatible_with(repeats.shape[0])

    # If we know that `repeats` is a scalar, then we can just tile & reshape.
    if repeats.shape.ndims == 0:
      expanded = array_ops.expand_dims(data, axis + 1)
      tiled = tile_one_dimension(expanded, axis + 1, repeats)
      # Merge the tiled dimension back into `axis` (the `-1` absorbs the
      # product of the original axis size and the repeat count).
      result_shape = array_ops.concat(
          [data_shape[:axis], [-1], data_shape[axis + 1:]], axis=0)
      return array_ops.reshape(tiled, result_shape)

    # Broadcast the `repeats` tensor so rank(repeats) == axis + 1.
    if repeats.shape.ndims != axis + 1:
      repeats_shape = array_ops.shape(repeats)
      repeats_ndims = array_ops.rank(repeats)
      broadcast_shape = array_ops.concat(
          [data_shape[:axis + 1 - repeats_ndims], repeats_shape], axis=0)
      repeats = array_ops.broadcast_to(repeats, broadcast_shape)
      repeats.set_shape([None] * (axis + 1))

    # Create a "sequence mask" based on `repeats`, where slices across `axis`
    # contain one `True` value for each repetition.  E.g., if
    # `repeats = [3, 1, 2]`, then `mask = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]`.
    max_repeat = math_ops.maximum(0, math_ops.reduce_max(repeats))
    mask = array_ops.sequence_mask(repeats, max_repeat)

    # Add a new dimension around each value that needs to be repeated, and
    # then tile that new dimension to match the maximum number of repetitions.
    expanded = array_ops.expand_dims(data, axis + 1)
    tiled = tile_one_dimension(expanded, axis + 1, max_repeat)

    # Use `boolean_mask` to discard the extra repeated values.  This also
    # flattens all dimensions up through `axis`.
    masked = array_ops.boolean_mask(tiled, mask)

    # Reshape the output tensor to add the outer dimensions back.
    if axis == 0:
      result = masked
    else:
      result_shape = array_ops.concat(
          [data_shape[:axis], [-1], data_shape[axis + 1:]], axis=0)
      result = array_ops.reshape(masked, result_shape)

    # Preserve shape information.
    if data.shape.ndims is not None:
      # If there are zero elements to repeat, the result axis is statically 0;
      # otherwise its size (sum of repeats) is unknown at graph-build time.
      new_axis_size = 0 if repeats.shape[0] == 0 else None
      result.set_shape(data.shape[:axis].concatenate(
          [new_axis_size]).concatenate(data.shape[axis + 1:]))

    return result


def tile_one_dimension(data, axis, multiple):
  """Tiles a single dimension of a tensor."""
  # Assumes axis is a nonnegative int.
  if data.shape.ndims is not None:
    # Rank is known statically: build the multiples list in Python.
    multiples = [1] * data.shape.ndims
    multiples[axis] = multiple
  else:
    # Rank is dynamic: build the multiples tensor with ops.
    ones = array_ops.ones(array_ops.rank(data), dtypes.int32)
    multiples = array_ops.concat([ones[:axis], [multiple], ones[axis + 1:]],
                                 axis=0)
  return array_ops.tile(data, multiples)


def _with_nonzero_rank(data):
  """If `data` is scalar, then add a dimension; otherwise return as-is."""
  if data.shape.ndims is not None:
    if data.shape.ndims == 0:
      return array_ops.stack([data])
    else:
      return data
  else:
    # Rank is unknown statically.  Prepending a 1 and keeping the last
    # max(rank, 1) dims is a no-op for rank >= 1 and yields shape [1] for
    # a scalar.
    data_shape = array_ops.shape(data)
    data_ndims = array_ops.rank(data)
    return array_ops.reshape(
        data,
        array_ops.concat([[1], data_shape], axis=0)[-data_ndims:])


def lengths_to_splits(lengths):
  """Returns splits corresponding to the given lengths."""
  # splits[0] = 0 and splits[i+1] = splits[i] + lengths[i].
  return array_ops.concat([[0], math_ops.cumsum(lengths)], axis=-1)


def repeat_ranges(params, splits, repeats):
  """Repeats each range of `params` (as specified by `splits`) `repeats` times.

  Let the `i`th range of `params` be defined as
  `params[splits[i]:splits[i + 1]]`.  Then this function returns a tensor
  containing range 0 repeated `repeats[0]` times, followed by range 1 repeated
  `repeats[1]`, ..., followed by the last range repeated `repeats[-1]` times.

  Args:
    params: The `Tensor` whose values should be repeated.
    splits: A splits tensor indicating the ranges of `params` that should be
      repeated.
    repeats: The number of times each range should be repeated.  Supports
      broadcasting from a scalar value.

  Returns:
    A `Tensor` with the same rank and type as `params`.

  #### Example:
    ```python
    >>> repeat_ranges(['a', 'b', 'c'], [0, 2, 3], 3)
    ['a', 'b', 'a', 'b', 'a', 'b', 'c', 'c', 'c']
    ```
  """
  # Divide `splits` into starts and limits, and repeat them `repeats` times.
  if repeats.shape.ndims != 0:
    repeated_starts = repeat(splits[:-1], repeats, axis=0)
    repeated_limits = repeat(splits[1:], repeats, axis=0)
  else:
    # Optimization: we can just call repeat once, and then slice the result.
    repeated_splits = repeat(splits, repeats, axis=0)
    n_splits = array_ops.shape(repeated_splits, out_type=dtypes.int64)[0]
    repeated_starts = repeated_splits[:n_splits - repeats]
    repeated_limits = repeated_splits[repeats:]

  # Get indices for each range from starts to limits, and use those to gather
  # the values in the desired repetition pattern.
  one = array_ops.ones((), repeated_starts.dtype)
  offsets = gen_ragged_math_ops.ragged_range(
      repeated_starts, repeated_limits, one)
  return array_ops.gather(params, offsets.rt_dense_values)