# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to convert between RaggedTensors and other tensor types."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_ragged_conversion_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor


def from_tensor(tensor,
                lengths=None,
                padding=None,
                ragged_rank=1,
                row_splits_dtype=dtypes.int64,
                name=None):
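  """Converts `tensor` to a `RaggedTensor` if it is not already ragged.

  If `tensor` is already a `RaggedTensor`, it is returned unchanged; otherwise
  it is converted with `RaggedTensor.from_tensor`, using `lengths` or `padding`
  (if specified) to determine which values should be omitted.
  """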
  if ragged_tensor.is_ragged(tensor):
    return tensor
  else:
    return ragged_tensor.RaggedTensor.from_tensor(
        tensor,
        lengths=lengths,
        padding=padding,
        ragged_rank=ragged_rank,
        row_splits_dtype=row_splits_dtype,
        name=name)


def to_tensor(rt_input, default_value=None, name=None):
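  """Converts `rt_input` to a dense `Tensor`, filling with `default_value`.

  If `rt_input` is not ragged, it is returned unchanged.
  """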
  if ragged_tensor.is_ragged(rt_input):
    return rt_input.to_tensor(default_value, name)
  else:
    return rt_input


def ragged_to_dense(rt_input, default_value=None, shape=None):
  """Create a dense tensor from a ragged tensor."""
  return rt_input.to_tensor(default_value=default_value, shape=shape)


@ops.RegisterGradient("RaggedTensorToTensor")
def _ragged_tensor_to_tensor_grad(op, grad):
  """Gradient for the RaggedTensorToTensor op."""
  # Extract inputs from the op.
  flat_values = op.inputs[1]
  default_value = op.inputs[2]
  row_partition_tensors = op.inputs[3:]
  row_partition_types = op.get_attr("row_partition_types")
  flat_value_shape = array_ops.shape(flat_values)
  ragged_rank = sum(
      1 for typ in row_partition_types if typ != b"FIRST_DIM_SIZE")

  # Create two tensors that correspond 1:1 with grad (and op.outputs[0]):
  # * indices[i1...iN] is the index in `flat_values` of the value used to
  #   populate output[i1...iN] (if the value came from `flat_values`) or
  #   -1 (if the value came from `default_value`).
  # * mask[i1...iN] is true if output[i1...iN] came from `flat_values`, or
  #   false if it came from `default_value`.
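  # For example, if the ragged input has rows [[a, b], [c]] (so `flat_values`
  # is [a, b, c]) and the dense output has shape [2, 2], then
  # indices = [[0, 1], [2, -1]] and mask = [[True, True], [True, False]].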
  indices = gen_ragged_conversion_ops.ragged_tensor_to_tensor(
      shape=array_ops.shape(grad)[:1 + ragged_rank],
      values=math_ops.range(flat_value_shape[0]),
      default_value=-1,
      row_partition_types=row_partition_types,
      row_partition_tensors=row_partition_tensors)
  mask = math_ops.not_equal(indices, -1)

  # Select out the gradients & indices that came from `flat_values`, and use
  # those to construct the gradient for `flat_values` (as an IndexedSlices).
  values_grad = indexed_slices.IndexedSlices(
      values=array_ops.boolean_mask(grad, mask),
      indices=array_ops.boolean_mask(indices, mask),
      dense_shape=flat_value_shape)

  # Select out the gradients that came from `default_value`, and sum them to
  # get the gradient for the default.  Note that the default_value may have
  # been broadcast as part of the RaggedTensorToTensor operation, so we also
  # need to reduce any dimensions that might have been broadcast.
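  # For example, if `grad` has shape [2, 2, 3] and `default_value` has shape
  # [3] (it was broadcast across the output cells that take the default), then
  # `default_grads` has shape [num_default_cells, 3], and summing over its
  # leading dimension yields a gradient of shape [3].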
  default_grads = array_ops.boolean_mask(grad, ~mask)
  dims_to_reduce = math_ops.range(
      array_ops.rank(default_grads) -
      _rank_ignoring_leading_dims_with_size_1(default_value))
  default_grad = math_ops.reduce_sum(default_grads, axis=dims_to_reduce)

  # Restore any leading dims with size one.
  default_grad = array_ops.reshape(default_grad, array_ops.shape(default_value))

  return ([None, values_grad, default_grad] +
          [None for _ in row_partition_tensors])


def _rank_ignoring_leading_dims_with_size_1(value):
  """Returns `rank(value)`, ignoring any leading dimensions with size 1."""
  # Compute the result using static shape, if possible.
  if value.shape.rank is not None:
    ndims = value.shape.rank
    for dim in value.shape.dims:
      if dim.value == 1:
        ndims -= 1
      elif dim.value is None:
        ndims = None  # Can't compute the result using static shape.
        break
      else:
        break
    if ndims is not None:
      return ndims

  # Otherwise, we need to compute the result dynamically.  The math we use to
  # do this is a bit round-about, so here's an example to illustrate:
  #              shape = [1, 1, 3, 5, 1, 4]  # shape(value)
  #         dim_is_one = [1, 1, 0, 0, 1, 0]  # equal(shape, 1)
  #       leading_ones = [1, 1, 0, 0, 0, 0]  # cumprod(dim_is_one)
  #   num_leading_ones = 2                   # reduce_sum(leading_ones)
  #             result = 4                   # rank(value) - num_leading_ones
  shape = array_ops.shape(value)
  dim_is_one = math_ops.cast(math_ops.equal(shape, 1), dtypes.int32)
  leading_ones = math_ops.cumprod(dim_is_one)
  num_leading_ones = math_ops.reduce_sum(leading_ones)
  return array_ops.rank(value) - num_leading_ones


def to_sparse(rt_input, name=None):
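  """Converts the `RaggedTensor` `rt_input` to a `SparseTensor`."""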
  return rt_input.to_sparse(name)


def from_sparse(st_input, name=None):
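  """Converts the `SparseTensor` `st_input` to a `RaggedTensor`."""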
  return ragged_tensor.RaggedTensor.from_sparse(st_input, name)


@ops.RegisterGradient("RaggedTensorFromVariant")
def _ragged_tensor_from_variant_grad(op, *grads):
  """Gradient for RaggedTensorFromVariant op."""

  variant_rank = op.inputs[0].shape.rank
  if variant_rank == 0:
    batched_input = False
  elif variant_rank == 1:
    batched_input = True
  elif variant_rank is None:
    batched_input = (op.get_attr("output_ragged_rank") > 0)
  else:
    # TODO(edloper): Add a batch_dims argument to RaggedTensorToVariant, so
    # we can support this.
    raise ValueError("Unable to compute gradient: RaggedTensorToVariant "
                     "can currently only generate 0D or 1D output.")
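  # The gradient is the incoming gradient for the dense values, re-encoded as
  # a variant using the nested row splits produced by this op.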
  return [
      gen_ragged_conversion_ops.ragged_tensor_to_variant(
          rt_nested_splits=op.outputs[:-1],
          rt_dense_values=grads[-1],
          batched_input=batched_input)
  ]


@ops.RegisterGradient("RaggedTensorToVariant")
def _ragged_tensor_to_variant_grad(op, encoded_ragged_grad):
  """Gradient for RaggedTensorToVariant op."""
  dense_values = op.inputs[-1]
  ragged_rank = len(op.inputs) - 1
  row_splits = 0 if ragged_rank == 0 else op.inputs[0]
  values_grad = gen_ragged_conversion_ops.ragged_tensor_to_variant_gradient(
      encoded_ragged_grad=encoded_ragged_grad,
      row_splits=row_splits,
      dense_values_shape=array_ops.shape(dense_values),
      Tvalues=op.inputs[-1].dtype)
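  # The op's inputs are `ragged_rank` row-splits tensors followed by the dense
  # values; only the dense values receive a gradient.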
  result = [None] * ragged_rank + [values_grad]
  return result