• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Value for RaggedTensor."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.util.tf_export import tf_export
24
25
@tf_export(v1=["ragged.RaggedTensorValue"])
class RaggedTensorValue(object):
  """Holds the concrete numpy value of a `RaggedTensor`.

  Warning: `RaggedTensorValue` is meant for graph mode only; in eager mode a
  `tf.RaggedTensor` carries its value directly.

  A value is made of a `values` component (a numpy array, or another
  `RaggedTensorValue` for additional ragged dimensions) and a 1-D int64
  `row_splits` array.  Row `i` is `values[row_splits[i]:row_splits[i + 1]]`.

  See `tf.RaggedTensor` for a description of ragged tensors.
  """

  def __init__(self, values, row_splits):
    """Builds a `RaggedTensorValue` from its components.

    Only the type, dtype and rank of `row_splits` are checked here; its
    contents (ordering, bounds) are not validated.

    Args:
      values: A numpy array of any dtype and shape holding the concatenated
        rows, or a nested `RaggedTensorValue`.
      row_splits: A 1-D numpy array of dtype int64 giving the split indices
        into `values`.

    Raises:
      TypeError: If `row_splits` is not a 1-D int64 numpy array, or if
        `values` is neither a numpy array nor a `RaggedTensorValue`.
    """
    # The isinstance test runs first so `.dtype` / `.ndim` are only read from
    # numpy objects (the `or` chain short-circuits otherwise).
    if (not isinstance(row_splits, (np.ndarray, np.generic)) or
        row_splits.dtype != np.int64 or row_splits.ndim != 1):
      raise TypeError("row_splits must be a 1D int64 numpy array")
    if not isinstance(values, (RaggedTensorValue, np.ndarray, np.generic)):
      raise TypeError("values must be a numpy array or a RaggedTensorValue")
    self._row_splits = row_splits
    self._values = values

  @property
  def row_splits(self):
    """The split indices that partition `values` into rows."""
    return self._row_splits

  @property
  def values(self):
    """The concatenated values for all rows of this ragged value."""
    return self._values

  @property
  def dtype(self):
    """The numpy dtype of the elements stored in this ragged value."""
    return self._values.dtype

  @property
  def flat_values(self):
    """The innermost, non-ragged `values` array of this ragged value."""
    current = self
    while isinstance(current, RaggedTensorValue):
      current = current.values
    return current

  @property
  def nested_row_splits(self):
    """A tuple of the `row_splits` of every ragged level, outermost first."""
    splits = []
    current = self
    while isinstance(current, RaggedTensorValue):
      splits.append(current.row_splits)
      current = current.values
    return tuple(splits)

  @property
  def ragged_rank(self):
    """The number of ragged dimensions (levels of nesting) in this value."""
    inner = self._values
    if not isinstance(inner, RaggedTensorValue):
      return 1
    return inner.ragged_rank + 1

  @property
  def shape(self):
    """A tuple describing the shape of this value.

    The first entry is the row count (`len(row_splits) - 1`), the second is
    `None` because row lengths vary, and the rest is `values.shape[1:]`.
    """
    num_rows = self._row_splits.shape[0] - 1
    return (num_rows, None) + self._values.shape[1:]

  def __str__(self):
    rows = self.to_list()
    return "<tf.RaggedTensorValue %s>" % rows

  def __repr__(self):
    fields = (self._values, self._row_splits)
    return "tf.RaggedTensorValue(values=%r, row_splits=%r)" % fields

  def to_list(self):
    """Converts this ragged value into a nested Python list.

    Returns:
      A list with one entry per row; each entry is the slice of `values`
      (itself converted to Python lists) between consecutive `row_splits`.
    """
    inner = self._values
    flat = (inner.to_list() if isinstance(inner, RaggedTensorValue)
            else inner.tolist())
    splits = self._row_splits
    # Pair each split with its successor to get the [start, limit) of a row.
    starts, limits = splits[:-1], splits[1:]
    return [flat[start:limit] for start, limit in zip(starts, limits)]
107