1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Value for RaggedTensor.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.util.tf_export import tf_export 24 25 26@tf_export(v1=["ragged.RaggedTensorValue"]) 27class RaggedTensorValue(object): 28 """Represents the value of a `RaggedTensor`. 29 30 Warning: `RaggedTensorValue` should only be used in graph mode; in 31 eager mode, the `tf.RaggedTensor` class contains its value directly. 32 33 See `tf.RaggedTensor` for a description of ragged tensors. 34 """ 35 36 def __init__(self, values, row_splits): 37 """Creates a `RaggedTensorValue`. 38 39 Args: 40 values: A numpy array of any type and shape; or a RaggedTensorValue. 41 row_splits: A 1-D int64 numpy array. 42 """ 43 if not (isinstance(row_splits, (np.ndarray, np.generic)) and 44 row_splits.dtype == np.int64 and row_splits.ndim == 1): 45 raise TypeError("row_splits must be a 1D int64 numpy array") 46 if not isinstance(values, (np.ndarray, np.generic, RaggedTensorValue)): 47 raise TypeError("values must be a numpy array or a RaggedTensorValue") 48 self._values = values 49 self._row_splits = row_splits 50 51 row_splits = property( 52 lambda self: self._row_splits, 53 doc="""The split indices for the ragged tensor value.""") 54 values = property( 55 lambda self: self._values, 56 doc="""The concatenated values for all rows in this tensor.""") 57 dtype = property( 58 lambda self: self._values.dtype, 59 doc="""The numpy dtype of values in this tensor.""") 60 61 @property 62 def flat_values(self): 63 """The innermost `values` array for this ragged tensor value.""" 64 rt_values = self.values 65 while isinstance(rt_values, RaggedTensorValue): 66 rt_values = rt_values.values 67 return rt_values 68 69 @property 70 def nested_row_splits(self): 71 """The row_splits for all ragged dimensions in this ragged tensor value.""" 72 rt_nested_splits = [self.row_splits] 73 rt_values = self.values 74 while isinstance(rt_values, RaggedTensorValue): 75 rt_nested_splits.append(rt_values.row_splits) 76 rt_values = rt_values.values 77 return tuple(rt_nested_splits) 78 79 @property 80 def ragged_rank(self): 81 """The number of ragged dimensions in this ragged tensor value.""" 82 values_is_ragged = isinstance(self._values, RaggedTensorValue) 83 return self._values.ragged_rank + 1 if values_is_ragged else 1 84 85 @property 86 def shape(self): 87 """A tuple indicating the shape of this RaggedTensorValue.""" 88 return (self._row_splits.shape[0] - 1,) + (None,) + self._values.shape[1:] 89 90 def __str__(self): 91 return "<tf.RaggedTensorValue %s>" % self.to_list() 92 93 def __repr__(self): 94 return "tf.RaggedTensorValue(values=%r, row_splits=%r)" % (self._values, 95 self._row_splits) 96 97 def to_list(self): 98 """Returns this ragged tensor value as a nested Python list.""" 99 if isinstance(self._values, RaggedTensorValue): 100 values_as_list = self._values.to_list() 101 else: 102 values_as_list = self._values.tolist() 103 return [ 104 values_as_list[self._row_splits[i]:self._row_splits[i + 1]] 105 for i in range(len(self._row_splits) - 1) 106 ] 107