1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <utility>
16 #include <vector>
17
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/framework/register_types.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/variant.h"
22 #include "tensorflow/core/framework/variant_encode_decode.h"
23 #include "tensorflow/core/kernels/ragged_tensor_variant.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26
27 namespace tensorflow {
28 namespace {
29
RaggedComponentsFromVariant(const Tensor & encoded_variant,int ragged_rank,DataType value_dtype,DataType split_dtype,std::vector<RaggedTensorVariant> * decoded_ragged)30 Status RaggedComponentsFromVariant(
31 const Tensor& encoded_variant, int ragged_rank, DataType value_dtype,
32 DataType split_dtype, std::vector<RaggedTensorVariant>* decoded_ragged) {
33 const auto& flat_variants = encoded_variant.flat<Variant>();
34 decoded_ragged->reserve(flat_variants.size());
35
36 for (int i = 0; i < flat_variants.size(); i++) {
37 const auto& flat_variant = flat_variants(i);
38 const RaggedTensorVariant* decoded =
39 flat_variant.get<RaggedTensorVariant>();
40 if (decoded == nullptr) {
41 return errors::InvalidArgument(
42 "Input Variant element at index ", i,
43 " doesn't hold a RaggedTensorVariant: ", flat_variant.DebugString());
44 }
45 decoded_ragged->push_back(*decoded);
46 decoded = &decoded_ragged->back();
47 // Check ragged rank & types
48 if (decoded->ragged_rank() != ragged_rank) {
49 return errors::InvalidArgument(
50 "Encoded input RaggedTensorVariant has ragged_rank=",
51 decoded->ragged_rank(), ". Expected ragged_rank=", ragged_rank, ".");
52 }
53 if (decoded->values().dtype() != value_dtype) {
54 return errors::InvalidArgument(
55 "Expected values Tensor dtype: ", DataTypeString(value_dtype),
56 ", found: ", DataTypeString(decoded->values().dtype()));
57 }
58 if (decoded->values().dims() < 1) {
59 return errors::InvalidArgument(
60 "Ragged values must have rank >= 1; encoded scalar element at index ",
61 i, " has values Tensor: ", decoded->values().DebugString());
62 }
63 for (const auto& splits : decoded->nested_splits()) {
64 if (splits.dtype() != split_dtype) {
65 return errors::InvalidArgument(
66 "Expected row_splits Tensor dtype: ", DataTypeString(split_dtype),
67 ", found: ", DataTypeString(splits.dtype()));
68 }
69 if (splits.dims() != 1) {
70 return errors::InvalidArgument(
71 "Ragged splits must have rank 1; encoded scalar element at index ",
72 i, " has splits Tensor ", splits.DebugString());
73 }
74 }
75 }
76 return Status::OK();
77 }
78
79 template <typename VALUE_TYPE, typename SPLIT_TYPE>
NestedStackRaggedTensors(const std::vector<RaggedTensorVariant> & ragged_components,const std::vector<int> & nested_dim_sizes,const int input_ragged_rank,const int output_ragged_rank,RaggedTensorVariant * output_ragged)80 Status NestedStackRaggedTensors(
81 const std::vector<RaggedTensorVariant>& ragged_components,
82 const std::vector<int>& nested_dim_sizes, const int input_ragged_rank,
83 const int output_ragged_rank, RaggedTensorVariant* output_ragged) {
84 output_ragged->mutable_nested_splits()->reserve(output_ragged_rank);
85 const int dims = nested_dim_sizes.size();
86
87 // Populate first `dims - 1` splits.
88 for (int i = 0; i < dims - 1; i++) {
89 int dims_splits_size = nested_dim_sizes[i] + 1;
90 output_ragged->append_splits(Tensor(DataTypeToEnum<SPLIT_TYPE>::value,
91 TensorShape({dims_splits_size})));
92 auto splits_vec = output_ragged->mutable_splits(i)->vec<SPLIT_TYPE>();
93 int split_diff = nested_dim_sizes[i + 1];
94 for (int j = 0; j < dims_splits_size; j++) {
95 splits_vec(j) = j * split_diff;
96 }
97 }
98
99 // Populate `dims`-th split.
100 int splits_size = ragged_components.size() + 1;
101 output_ragged->append_splits(
102 Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({splits_size})));
103 auto dims_splits_vec =
104 output_ragged->mutable_splits(dims - 1)->vec<SPLIT_TYPE>();
105 dims_splits_vec(0) = 0;
106 for (int i = 0; i < ragged_components.size(); i++) {
107 int split_val = ragged_components[i].values().shape().dim_size(0);
108 if (input_ragged_rank != 0 && ragged_components[i].ragged_rank() > 0) {
109 split_val = ragged_components[i].splits(0).NumElements() - 1;
110 }
111 dims_splits_vec(i + 1) = dims_splits_vec(i) + split_val;
112 }
113
114 // Populate last `input_ragged_rank` splits.
115 for (int i = 0; i < input_ragged_rank; i++) {
116 int split_index = dims + i;
117 int split_size = 1;
118 for (int j = 0; j < ragged_components.size(); j++) {
119 if (!ragged_components[j].nested_splits().empty()) {
120 split_size += ragged_components[j].splits(i).NumElements() - 1;
121 }
122 }
123 output_ragged->append_splits(
124 Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
125 auto splits_vec =
126 output_ragged->mutable_splits(split_index)->vec<SPLIT_TYPE>();
127 splits_vec(0) = 0;
128 SPLIT_TYPE last_split_value = 0;
129 int index = 1;
130 for (int j = 0; j < ragged_components.size(); j++) {
131 if (ragged_components[j].nested_splits().empty()) {
132 // Corner case: empty row. e.g [ [[x], [x]], [] ]
133 continue;
134 }
135 auto component_splits_vec =
136 ragged_components[j].splits(i).vec<SPLIT_TYPE>();
137 for (int k = 1; k < component_splits_vec.size(); k++, index++) {
138 splits_vec(index) = component_splits_vec(k) + last_split_value;
139 }
140 last_split_value = splits_vec(index - 1);
141 }
142 }
143
144 // If the variant tensor input is empty, then we have no way to determine
145 // the correct shape for the dense_values. (It must have rank>=1, and its
146 // outer dimension must be 0, but we don't know its shape beyond that.)
147 // For now, we just use a shape of `[0]` in this case.
148 // TODO(edloper): Update this op with an attribute containing information
149 // about dense_values shape. If it's `None`, then we'll probably still have
150 // to use shape=[0] here, but if we have more info, then we can use it.
151 // E.g., in map_fn, we may have shape info from the RaggedTensorSpec.
152 TensorShape component_values_shape;
153 if (ragged_components.empty()) {
154 component_values_shape = TensorShape({0});
155 } else {
156 component_values_shape = ragged_components[0].values().shape();
157 }
158
159 // Populate values.
160 int values_size = component_values_shape.dim_size(0);
161 for (int i = 1; i < ragged_components.size(); i++) {
162 if (ragged_components[i].values().dims() != component_values_shape.dims()) {
163 return errors::InvalidArgument(
164 "Rank of values must match for all "
165 "components; values shape at index 0: ",
166 component_values_shape.DebugString(), ", values shape at index ", i,
167 ": ", ragged_components[i].values().shape().DebugString());
168 }
169 values_size += ragged_components[i].values().shape().dim_size(0);
170 }
171 component_values_shape.set_dim(0, values_size);
172 output_ragged->set_values(
173 Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape));
174 auto output_values_flat =
175 output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
176 int values_index = 0;
177
178 TensorShape expected_value_shape = component_values_shape;
179 expected_value_shape.RemoveDim(0);
180
181 for (int i = 0; i < ragged_components.size(); i++) {
182 // Check that the flat_values tensor shape is compatible.
183 TensorShape value_shape = ragged_components[i].values().shape();
184 value_shape.RemoveDim(0);
185 if (value_shape != expected_value_shape) {
186 return errors::InvalidArgument(
187 "All flat_values must have compatible shapes. Shape at index 0: ",
188 expected_value_shape, ". Shape at index ", i, ": ", value_shape,
189 ". If you are using tf.map_fn, then you may need to specify an "
190 "explicit fn_output_signature with appropriate ragged_rank, and/or "
191 "convert output tensors to RaggedTensors.");
192 }
193
194 auto component_values_flat =
195 ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
196 int num_inner_elements = ragged_components[i].values().NumElements();
197 if (ragged_components[i].values().dim_size(0) > 0) {
198 num_inner_elements /= ragged_components[i].values().dim_size(0);
199 }
200 for (int j = 0; j < ragged_components[i].values().dim_size(0);
201 j++, values_index++) {
202 for (int k = 0; k < num_inner_elements; k++) {
203 output_values_flat(values_index, k) = component_values_flat(j, k);
204 }
205 }
206 }
207 return Status::OK();
208 }
209 } // namespace
210
211 template <typename VALUE_TYPE, typename SPLIT_TYPE>
212 class RaggedTensorFromVariantOp : public OpKernel {
213 public:
RaggedTensorFromVariantOp(OpKernelConstruction * context)214 explicit RaggedTensorFromVariantOp(OpKernelConstruction* context)
215 : OpKernel(context) {
216 OP_REQUIRES_OK(context, context->GetAttr("input_ragged_rank",
217 &input_ragged_rank_attr_));
218 OP_REQUIRES_OK(
219 context, context->GetAttr("output_ragged_rank", &output_ragged_rank_));
220 }
221
Compute(OpKernelContext * context)222 void Compute(OpKernelContext* context) override {
223 // Read input Tensor.
224 const Tensor& encoded_variant = context->input(0);
225 auto input_ragged_rank_ = input_ragged_rank_attr_;
226
227 if (input_ragged_rank_ == -1) { // Infer input_ragged_rank_.
228 input_ragged_rank_ = output_ragged_rank_ - encoded_variant.dims();
229 OP_REQUIRES(context, input_ragged_rank_ >= 0,
230 errors::InvalidArgument(
231 "Inferred input_ragged_rank (output_ragged_rank - "
232 "encoded_variant.dims()) must be >= 0, found "
233 "output_ragged_rank: ",
234 output_ragged_rank_,
235 ", encoded_variant.dims(): ", encoded_variant.dims(),
236 ", inferred input_ragged_rank: ", input_ragged_rank_));
237 }
238 OP_REQUIRES(
239 context,
240 output_ragged_rank_ == encoded_variant.dims() + input_ragged_rank_,
241 errors::InvalidArgument(
242 "output_ragged_rank must be equal to input_ragged_rank + "
243 "encoded_ragged.dims(); output_ragged_rank: ",
244 output_ragged_rank_, ", input_ragged_rank: ", input_ragged_rank_,
245 ", encoded_variant.dims(): ", encoded_variant.dims(), "."));
246
247 // Decode all variants.
248 const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
249 const auto split_dtype = DataTypeToEnum<SPLIT_TYPE>::v();
250 std::vector<RaggedTensorVariant> decoded_components;
251 OP_REQUIRES_OK(context, RaggedComponentsFromVariant(
252 encoded_variant, input_ragged_rank_,
253 value_dtype, split_dtype, &decoded_components));
254
255 // Corner case: input is a scalar.
256 if (encoded_variant.dims() == 0) {
257 ReturnRaggedTensor(context, decoded_components[0]);
258 return;
259 }
260
261 // Nested-Stack Ragged components into a batched RaggedTensor.
262 std::vector<int> encoded_dim_sizes(encoded_variant.dims(), 0);
263 for (int i = 0; i < encoded_variant.dims(); i++) {
264 encoded_dim_sizes[i] = encoded_variant.dim_size(i);
265 }
266 RaggedTensorVariant output_ragged;
267 OP_REQUIRES_OK(
268 context, NestedStackRaggedTensors<VALUE_TYPE, SPLIT_TYPE>(
269 decoded_components, encoded_dim_sizes, input_ragged_rank_,
270 output_ragged_rank_, &output_ragged));
271
272 // Set output.
273 ReturnRaggedTensor(context, output_ragged);
274 }
275
276 private:
277 int input_ragged_rank_attr_;
278 int output_ragged_rank_;
279
ReturnRaggedTensor(OpKernelContext * context,const RaggedTensorVariant & ragged_tensor)280 void ReturnRaggedTensor(OpKernelContext* context,
281 const RaggedTensorVariant& ragged_tensor) {
282 int ragged_rank = ragged_tensor.ragged_rank();
283 OpOutputList splits_out;
284 OP_REQUIRES_OK(context,
285 context->output_list("output_nested_splits", &splits_out));
286 for (int i = 0; i < ragged_rank; i++) {
287 splits_out.set(i, ragged_tensor.splits(i));
288 }
289 context->set_output(ragged_rank, ragged_tensor.values());
290 }
291 };
292
// Registers RaggedTensorFromVariantOp on CPU for one (value, split) dtype
// pair, binding the "Tvalues" and "Tsplits" type attrs.
#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)      \
  REGISTER_KERNEL_BUILDER(Name("RaggedTensorFromVariant")             \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<value_type>("Tvalues")  \
                              .TypeConstraint<split_type>("Tsplits"), \
                          RaggedTensorFromVariantOp<value_type, split_type>);
// Registers both supported split types (int32 and int64) for a value type.
#define REGISTER_KERNELS(value_type)                  \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64)
// Instantiate kernels for all POD, string, and quantized value types.
TF_CALL_POD_TYPES(REGISTER_KERNELS);
TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE
309 } // namespace tensorflow
310