# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`where` operations (v1 and v2 semantics) for RaggedTensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_shape


def where_v2(condition, x=None, y=None, name=None):
  """Return the elements where `condition` is `True`.

  : If both `x` and `y` are None: Retrieve indices of true elements.

    Returns the coordinates of true elements of `condition`. The coordinates
    are returned in a 2-D tensor with shape
    `[num_true_values, rank(condition)]`, where `result[i]` is the
    coordinates of the `i`th true value (in row-major order).

  : If both `x` and `y` are non-`None`: Multiplex between `x` and `y`.

    Choose an output shape from the shapes of `condition`, `x`, and `y` that
    all three shapes are broadcastable to; and then use the broadcasted
    `condition` tensor as a mask that chooses whether the corresponding element
    in the output should be taken from `x` (if `condition` is true) or `y` (if
    `condition` is false).

  >>> # Example: retrieve indices of true elements
  >>> tf.where(tf.ragged.constant([[True, False], [True]]))
  <tf.Tensor: shape=(2, 2), dtype=int64, numpy= array([[0, 0], [1, 0]])>

  >>> # Example: multiplex between `x` and `y`
  >>> tf.where(tf.ragged.constant([[True, False], [True, False, True]]),
  ... tf.ragged.constant([['A', 'B'], ['C', 'D', 'E']]),
  ... tf.ragged.constant([['a', 'b'], ['c', 'd', 'e']]))
  <tf.RaggedTensor [[b'A', b'b'], [b'C', b'd', b'E']]>

  Args:
    condition: A potentially ragged tensor of type `bool`.
    x: A potentially ragged tensor (optional).
    y: A potentially ragged tensor (optional).  Must be specified if `x` is
      specified.  Must have the same rank and type as `x`.
    name: A name of the operation (optional).

  Returns:
    : If both `x` and `y` are `None`:
      A `Tensor` with shape `(num_true, rank(condition))`.
    : Otherwise:
      A potentially ragged tensor with the same type as `x` and `y`, and whose
      shape is broadcast-compatible with `x`, `y`, and `condition`.

  Raises:
    ValueError: When exactly one of `x` or `y` is non-`None`; or when
      `condition`, `x`, and `y` have incompatible shapes.
  """
  # Validate the x/y pairing eagerly, before any ops are created.
  if (x is None) != (y is None):
    raise ValueError('x and y must be either both None or both non-None')

  # All ops created below are grouped under one name scope
  # ('RaggedWhere', or `name` when given).
  with ops.name_scope('RaggedWhere', name, [condition, x, y]):
    condition = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        condition, name='condition')
    if x is None:
      # Coordinate mode: return the indices of the true elements.
      return _coordinate_where(condition)
    else:
      x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
      y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y')
      # NOTE(review): presumably casts row_splits so that all ragged inputs
      # share one row-partition dtype before they are combined -- confirm.
      condition, x, y = ragged_tensor.match_row_splits_dtypes(condition, x, y)
      return _elementwise_where_v2(condition, x, y)


def where(condition, x=None, y=None, name=None):
  """Return the elements, either from `x` or `y`, depending on the `condition`.

  : If both `x` and `y` are `None`:
    Returns the coordinates of true elements of `condition`. The coordinates
    are returned in a 2-D tensor with shape
    `[num_true_values, rank(condition)]`, where `result[i]` is the
    coordinates of the `i`th true value (in row-major order).

  : If both `x` and `y` are non-`None`:
    Returns a tensor formed by selecting values from `x` where condition is
    true, and from `y` when condition is false.  In particular:

    : If `condition`, `x`, and `y` all have the same shape:

      * `result[i1...iN] = x[i1...iN]` if `condition[i1...iN]` is true.
      * `result[i1...iN] = y[i1...iN]` if `condition[i1...iN]` is false.

    : Otherwise:

      * `condition` must be a vector.
      * `x` and `y` must have the same number of dimensions.
      * The outermost dimensions of `condition`, `x`, and `y` must all have the
        same size.
      * `result[i] = x[i]` if `condition[i]` is true.
      * `result[i] = y[i]` if `condition[i]` is false.

  Args:
    condition: A potentially ragged tensor of type `bool`.
    x: A potentially ragged tensor (optional).
    y: A potentially ragged tensor (optional).  Must be specified if `x` is
      specified.  Must have the same rank and type as `x`.
    name: A name of the operation (optional).

  Returns:
    : If both `x` and `y` are `None`:
      A `Tensor` with shape `(num_true, rank(condition))`.
    : Otherwise:
      A potentially ragged tensor with the same type, rank, and outermost
      dimension size as `x` and `y`.
      `result.ragged_rank = max(x.ragged_rank, y.ragged_rank)`.

  Raises:
    ValueError: When exactly one of `x` or `y` is non-`None`; or when
      `condition`, `x`, and `y` have incompatible shapes.

  #### Examples:

  >>> # Coordinates where condition is true.
  >>> condition = tf.ragged.constant([[True, False, True], [False, True]])
  >>> print(where(condition))
  tf.Tensor( [[0 0] [0 2] [1 1]], shape=(3, 2), dtype=int64)

  >>> # Elementwise selection between x and y, based on condition.
  >>> condition = tf.ragged.constant([[True, False, True], [False, True]])
  >>> x = tf.ragged.constant([['A', 'B', 'C'], ['D', 'E']])
  >>> y = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
  >>> print(where(condition, x, y))
  <tf.RaggedTensor [[b'A', b'b', b'C'], [b'd', b'E']]>

  >>> # Row selection between x and y, based on condition.
  >>> condition = [True, False]
  >>> x = tf.ragged.constant([['A', 'B', 'C'], ['D', 'E']])
  >>> y = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e']])
  >>> print(where(condition, x, y))
  <tf.RaggedTensor [[b'A', b'B', b'C'], [b'd', b'e']]>
  """
  # Validate the x/y pairing eagerly, before any ops are created.
  if (x is None) != (y is None):
    raise ValueError('x and y must be either both None or both non-None')
  # All ops created below are grouped under one name scope
  # ('RaggedWhere', or `name` when given).
  with ops.name_scope('RaggedWhere', name, [condition, x, y]):
    condition = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        condition, name='condition')
    if x is None:
      # Coordinate mode: return the indices of the true elements.
      return _coordinate_where(condition)
    else:
      x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, name='x')
      y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, name='y')
      # NOTE(review): presumably casts row_splits so that all ragged inputs
      # share one row-partition dtype before they are combined -- confirm.
      condition, x, y = ragged_tensor.match_row_splits_dtypes(condition, x, y)
      return _elementwise_where(condition, x, y)


def _elementwise_where(condition, x, y):
  """Ragged version of tf.where(condition, x, y).

  Dispatches on which of the three inputs are `RaggedTensor`s:

  * None ragged: delegates to `array_ops.where`.
  * All ragged: applies `array_ops.where` to the flat values via
    `map_flat_values` (presumably this requires all three to share the same
    nested row partitions -- confirm against `map_flat_values`).
  * `condition` dense, `x` and/or `y` ragged: row selection.  `condition` must
    be a vector, and row `i` of the result is row `i` of `x` if `condition[i]`
    is true, else row `i` of `y`.  (Per the public `where` docstring,
    `condition`, `x`, and `y` are expected to have the same number of rows.)
  * `condition` ragged but `x` or `y` dense: unsupported.

  Args:
    condition: A `bool` `Tensor` or `RaggedTensor`.
    x: A `Tensor` or `RaggedTensor`.
    y: A `Tensor` or `RaggedTensor`.

  Returns:
    A `Tensor` or `RaggedTensor` whose elements (or rows) are taken from `x`
    or `y` according to `condition`.

  Raises:
    ValueError: If `condition` is ragged but `x` or `y` is not; or (via
      `assert_has_rank`) if `condition` is dense, `x` or `y` is ragged, and
      the static shape of `condition` is incompatible with rank 1.
  """
  condition_is_ragged = isinstance(condition, ragged_tensor.RaggedTensor)
  x_is_ragged = isinstance(x, ragged_tensor.RaggedTensor)
  y_is_ragged = isinstance(y, ragged_tensor.RaggedTensor)

  if not (condition_is_ragged or x_is_ragged or y_is_ragged):
    return array_ops.where(condition, x, y)

  elif condition_is_ragged and x_is_ragged and y_is_ragged:
    return ragged_functional_ops.map_flat_values(array_ops.where, condition, x,
                                                 y)
  elif not condition_is_ragged:
    # Concatenate x and y, and then use `gather` to assemble the selected rows.
    condition.shape.assert_has_rank(1)
    x_and_y = ragged_concat_ops.concat([x, y], axis=0)
    # Row counts use the concatenated row_splits dtype, so the gather
    # indices built below have that dtype.
    x_nrows = _nrows(x, out_type=x_and_y.row_splits.dtype)
    y_nrows = _nrows(y, out_type=x_and_y.row_splits.dtype)
    # In `x_and_y`, row i of `x` is at index i and row i of `y` is at index
    # `x_nrows + i`; choose one index per row according to `condition`.
    indices = array_ops.where(condition, math_ops.range(x_nrows),
                              x_nrows + math_ops.range(y_nrows))
    return ragged_gather_ops.gather(x_and_y, indices)

  else:
    # `condition` is ragged but at least one of `x` / `y` is dense.
    raise ValueError('Input shapes do not match.')


def _elementwise_where_v2(condition, x, y):
  """Ragged version of tf.where_v2(condition, x, y).

  Unless `condition`, `x`, and `y` all have identical, fully-defined static
  shapes, the three inputs are first broadcast to a common
  `RaggedTensorDynamicShape`.  If none of the (possibly broadcast) inputs is
  ragged, the result comes from `array_ops.where_v2`.  Otherwise,
  `array_ops.where_v2` is applied to their flat values via `map_flat_values`.

  Args:
    condition: A `bool` `Tensor` or `RaggedTensor`.
    x: A `Tensor` or `RaggedTensor`.
    y: A `Tensor` or `RaggedTensor`.

  Returns:
    A `Tensor` or `RaggedTensor` with the broadcast shape of the inputs, with
    elements taken from `x` where `condition` is true and from `y` elsewhere.
  """
  # Broadcast x, y, and condition to have the same shape.
  # Fast path: when all static shapes are fully defined and already equal,
  # skip building the dynamic-shape broadcast ops.
  if not (condition.shape.is_fully_defined() and x.shape.is_fully_defined() and
          y.shape.is_fully_defined() and x.shape == y.shape and
          condition.shape == x.shape):
    shape_c = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
        condition)
    shape_x = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x)
    shape_y = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y)
    shape = ragged_tensor_shape.broadcast_dynamic_shape(
        shape_c, ragged_tensor_shape.broadcast_dynamic_shape(shape_x, shape_y))
    condition = ragged_tensor_shape.broadcast_to(condition, shape)
    x = ragged_tensor_shape.broadcast_to(x, shape)
    y = ragged_tensor_shape.broadcast_to(y, shape)

  condition_is_ragged = isinstance(condition, ragged_tensor.RaggedTensor)
  x_is_ragged = isinstance(x, ragged_tensor.RaggedTensor)
  y_is_ragged = isinstance(y, ragged_tensor.RaggedTensor)
  if not (condition_is_ragged or x_is_ragged or y_is_ragged):
    return array_ops.where_v2(condition, x, y)

  # NOTE(review): assumes that, after broadcasting, all three inputs share the
  # same row partitions, so where_v2 can run directly on the flat values.
  return ragged_functional_ops.map_flat_values(array_ops.where_v2, condition, x,
                                               y)


def _coordinate_where(condition):
  """Ragged version of tf.where(condition).

  For a dense `condition`, this delegates to `array_ops.where`.  For a ragged
  `condition`, it recurses on `condition.values` and then turns each
  coordinate's leading index (a position in `condition.values`) into a
  `(row, column)` pair for the outermost ragged dimension.

  Args:
    condition: A `Tensor` or `RaggedTensor` (of type `bool`, per the public
      `where` docstring).

  Returns:
    A 2-D `Tensor` with one row per true element of `condition`.  For a ragged
    `condition`, each row is `[row, col, <coordinates inside condition.values
    beyond its first dimension>]`.
  """
  if not isinstance(condition, ragged_tensor.RaggedTensor):
    return array_ops.where(condition)

  # The coordinate for each `true` value in condition.values.
  selected_coords = _coordinate_where(condition.values)

  # Convert the first index in each coordinate to a row index and column index.
  # Cast row_splits to the coordinate dtype so the gathered row starts and
  # `first_index` can be subtracted without a dtype mismatch.
  condition = condition.with_row_splits_dtype(selected_coords.dtype)
  first_index = selected_coords[:, 0]
  # value_rowids()[k] is the row that holds values[k].
  selected_rows = array_ops.gather(condition.value_rowids(), first_index)
  # row_splits[r] is the offset in `values` where row r starts, so the column
  # is the value's offset minus its row's start offset.
  selected_row_starts = array_ops.gather(condition.row_splits, selected_rows)
  selected_cols = first_index - selected_row_starts

  # Assemble the row & column index with the indices for inner dimensions.
  return array_ops.concat([
      array_ops.expand_dims(selected_rows, 1),
      array_ops.expand_dims(selected_cols, 1), selected_coords[:, 1:]
  ],
                          axis=1)


def _nrows(rt_input, out_type):
  """Returns the number of rows (outermost dimension size) of `rt_input`.

  Args:
    rt_input: A `Tensor` or `RaggedTensor`.
    out_type: The dtype of the returned row count.

  Returns:
    `rt_input.nrows(out_type=out_type)` for a `RaggedTensor`, or the first
    element of `array_ops.shape(rt_input, out_type=out_type)` otherwise.
  """
  if isinstance(rt_input, ragged_tensor.RaggedTensor):
    return rt_input.nrows(out_type=out_type)
  else:
    return array_ops.shape(rt_input, out_type=out_type)[0]