• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <cstdint>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
26 
27 namespace tensorflow {
28 namespace {
29 
RaggedComponentsFromVariant(const Tensor & encoded_variant,int ragged_rank,DataType value_dtype,DataType split_dtype,std::vector<RaggedTensorVariant> * decoded_ragged)30 Status RaggedComponentsFromVariant(
31     const Tensor& encoded_variant, int ragged_rank, DataType value_dtype,
32     DataType split_dtype, std::vector<RaggedTensorVariant>* decoded_ragged) {
33   const auto& flat_variants = encoded_variant.flat<Variant>();
34   decoded_ragged->reserve(flat_variants.size());
35 
36   for (int i = 0; i < flat_variants.size(); i++) {
37     const auto& flat_variant = flat_variants(i);
38     const RaggedTensorVariant* decoded =
39         flat_variant.get<RaggedTensorVariant>();
40     if (decoded == nullptr) {
41       return errors::InvalidArgument(
42           "Input Variant element at index ", i,
43           " doesn't hold a RaggedTensorVariant: ", flat_variant.DebugString());
44     }
45     decoded_ragged->push_back(*decoded);
46     decoded = &decoded_ragged->back();
47     // Check ragged rank & types
48     if (decoded->ragged_rank() != ragged_rank) {
49       return errors::InvalidArgument(
50           "Encoded input RaggedTensorVariant has ragged_rank=",
51           decoded->ragged_rank(), ".  Expected ragged_rank=", ragged_rank, ".");
52     }
53     if (decoded->values().dtype() != value_dtype) {
54       return errors::InvalidArgument(
55           "Expected values Tensor dtype: ", DataTypeString(value_dtype),
56           ", found: ", DataTypeString(decoded->values().dtype()));
57     }
58     if (decoded->values().dims() < 1) {
59       return errors::InvalidArgument(
60           "Ragged values must have rank >= 1; encoded scalar element at index ",
61           i, " has values Tensor: ", decoded->values().DebugString());
62     }
63     for (const auto& splits : decoded->nested_splits()) {
64       if (splits.dtype() != split_dtype) {
65         return errors::InvalidArgument(
66             "Expected row_splits Tensor dtype: ", DataTypeString(split_dtype),
67             ", found: ", DataTypeString(splits.dtype()));
68       }
69       if (splits.dims() != 1) {
70         return errors::InvalidArgument(
71             "Ragged splits must have rank 1; encoded scalar element at index ",
72             i, " has splits Tensor ", splits.DebugString());
73       }
74     }
75   }
76   return Status::OK();
77 }
78 
79 template <typename VALUE_TYPE, typename SPLIT_TYPE>
NestedStackRaggedTensors(const std::vector<RaggedTensorVariant> & ragged_components,const std::vector<int> & nested_dim_sizes,const int input_ragged_rank,const int output_ragged_rank,RaggedTensorVariant * output_ragged)80 Status NestedStackRaggedTensors(
81     const std::vector<RaggedTensorVariant>& ragged_components,
82     const std::vector<int>& nested_dim_sizes, const int input_ragged_rank,
83     const int output_ragged_rank, RaggedTensorVariant* output_ragged) {
84   output_ragged->mutable_nested_splits()->reserve(output_ragged_rank);
85   const int dims = nested_dim_sizes.size();
86 
87   // Populate first `dims - 1` splits.
88   for (int i = 0; i < dims - 1; i++) {
89     int dims_splits_size = nested_dim_sizes[i] + 1;
90     output_ragged->append_splits(Tensor(DataTypeToEnum<SPLIT_TYPE>::value,
91                                         TensorShape({dims_splits_size})));
92     auto splits_vec = output_ragged->mutable_splits(i)->vec<SPLIT_TYPE>();
93     int split_diff = nested_dim_sizes[i + 1];
94     for (int j = 0; j < dims_splits_size; j++) {
95       splits_vec(j) = j * split_diff;
96     }
97   }
98 
99   // Populate `dims`-th split.
100   int splits_size = ragged_components.size() + 1;
101   output_ragged->append_splits(
102       Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({splits_size})));
103   auto dims_splits_vec =
104       output_ragged->mutable_splits(dims - 1)->vec<SPLIT_TYPE>();
105   dims_splits_vec(0) = 0;
106   for (int i = 0; i < ragged_components.size(); i++) {
107     int split_val = ragged_components[i].values().shape().dim_size(0);
108     if (input_ragged_rank != 0 && ragged_components[i].ragged_rank() > 0) {
109       split_val = ragged_components[i].splits(0).NumElements() - 1;
110     }
111     dims_splits_vec(i + 1) = dims_splits_vec(i) + split_val;
112   }
113 
114   // Populate last `input_ragged_rank` splits.
115   for (int i = 0; i < input_ragged_rank; i++) {
116     int split_index = dims + i;
117     int split_size = 1;
118     for (int j = 0; j < ragged_components.size(); j++) {
119       if (!ragged_components[j].nested_splits().empty()) {
120         split_size += ragged_components[j].splits(i).NumElements() - 1;
121       }
122     }
123     output_ragged->append_splits(
124         Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
125     auto splits_vec =
126         output_ragged->mutable_splits(split_index)->vec<SPLIT_TYPE>();
127     splits_vec(0) = 0;
128     SPLIT_TYPE last_split_value = 0;
129     int index = 1;
130     for (int j = 0; j < ragged_components.size(); j++) {
131       if (ragged_components[j].nested_splits().empty()) {
132         // Corner case: empty row. e.g [ [[x], [x]], [] ]
133         continue;
134       }
135       auto component_splits_vec =
136           ragged_components[j].splits(i).vec<SPLIT_TYPE>();
137       for (int k = 1; k < component_splits_vec.size(); k++, index++) {
138         splits_vec(index) = component_splits_vec(k) + last_split_value;
139       }
140       last_split_value = splits_vec(index - 1);
141     }
142   }
143 
144   // If the variant tensor input is empty, then we have no way to determine
145   // the correct shape for the dense_values.  (It must have rank>=1, and its
146   // outer dimension must be 0, but we don't know its shape beyond that.)
147   // For now, we just use a shape of `[0]` in this case.
148   // TODO(edloper): Update this op with an attribute containing information
149   // about dense_values shape.  If it's `None`, then we'll probably still have
150   // to use shape=[0] here, but if we have more info, then we can use it.
151   // E.g., in map_fn, we may have shape info from the RaggedTensorSpec.
152   TensorShape component_values_shape;
153   if (ragged_components.empty()) {
154     component_values_shape = TensorShape({0});
155   } else {
156     component_values_shape = ragged_components[0].values().shape();
157   }
158 
159   // Populate values.
160   int values_size = component_values_shape.dim_size(0);
161   for (int i = 1; i < ragged_components.size(); i++) {
162     if (ragged_components[i].values().dims() != component_values_shape.dims()) {
163       return errors::InvalidArgument(
164           "Rank of values must match for all "
165           "components; values shape at index 0: ",
166           component_values_shape.DebugString(), ", values shape at index ", i,
167           ": ", ragged_components[i].values().shape().DebugString());
168     }
169     values_size += ragged_components[i].values().shape().dim_size(0);
170   }
171   component_values_shape.set_dim(0, values_size);
172   output_ragged->set_values(
173       Tensor(DataTypeToEnum<VALUE_TYPE>::value, component_values_shape));
174   auto output_values_flat =
175       output_ragged->mutable_values()->flat_outer_dims<VALUE_TYPE, 2>();
176   int values_index = 0;
177 
178   TensorShape expected_value_shape = component_values_shape;
179   expected_value_shape.RemoveDim(0);
180 
181   for (int i = 0; i < ragged_components.size(); i++) {
182     // Check that the flat_values tensor shape is compatible.
183     TensorShape value_shape = ragged_components[i].values().shape();
184     value_shape.RemoveDim(0);
185     if (value_shape != expected_value_shape) {
186       return errors::InvalidArgument(
187           "All flat_values must have compatible shapes.  Shape at index 0: ",
188           expected_value_shape, ".  Shape at index ", i, ": ", value_shape,
189           ".  If you are using tf.map_fn, then you may need to specify an "
190           "explicit fn_output_signature with appropriate ragged_rank, and/or "
191           "convert output tensors to RaggedTensors.");
192     }
193 
194     auto component_values_flat =
195         ragged_components[i].values().flat_outer_dims<VALUE_TYPE, 2>();
196     int num_inner_elements = ragged_components[i].values().NumElements();
197     if (ragged_components[i].values().dim_size(0) > 0) {
198       num_inner_elements /= ragged_components[i].values().dim_size(0);
199     }
200     for (int j = 0; j < ragged_components[i].values().dim_size(0);
201          j++, values_index++) {
202       for (int k = 0; k < num_inner_elements; k++) {
203         output_values_flat(values_index, k) = component_values_flat(j, k);
204       }
205     }
206   }
207   return Status::OK();
208 }
209 }  // namespace
210 
211 template <typename VALUE_TYPE, typename SPLIT_TYPE>
212 class RaggedTensorFromVariantOp : public OpKernel {
213  public:
RaggedTensorFromVariantOp(OpKernelConstruction * context)214   explicit RaggedTensorFromVariantOp(OpKernelConstruction* context)
215       : OpKernel(context) {
216     OP_REQUIRES_OK(context, context->GetAttr("input_ragged_rank",
217                                              &input_ragged_rank_attr_));
218     OP_REQUIRES_OK(
219         context, context->GetAttr("output_ragged_rank", &output_ragged_rank_));
220   }
221 
Compute(OpKernelContext * context)222   void Compute(OpKernelContext* context) override {
223     // Read input Tensor.
224     const Tensor& encoded_variant = context->input(0);
225     auto input_ragged_rank_ = input_ragged_rank_attr_;
226 
227     if (input_ragged_rank_ == -1) {  // Infer input_ragged_rank_.
228       input_ragged_rank_ = output_ragged_rank_ - encoded_variant.dims();
229       OP_REQUIRES(context, input_ragged_rank_ >= 0,
230                   errors::InvalidArgument(
231                       "Inferred input_ragged_rank (output_ragged_rank - "
232                       "encoded_variant.dims()) must be >= 0, found "
233                       "output_ragged_rank: ",
234                       output_ragged_rank_,
235                       ", encoded_variant.dims(): ", encoded_variant.dims(),
236                       ", inferred input_ragged_rank: ", input_ragged_rank_));
237     }
238     OP_REQUIRES(
239         context,
240         output_ragged_rank_ == encoded_variant.dims() + input_ragged_rank_,
241         errors::InvalidArgument(
242             "output_ragged_rank must be equal to input_ragged_rank + "
243             "encoded_ragged.dims(); output_ragged_rank: ",
244             output_ragged_rank_, ", input_ragged_rank: ", input_ragged_rank_,
245             ", encoded_variant.dims(): ", encoded_variant.dims(), "."));
246 
247     // Decode all variants.
248     const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
249     const auto split_dtype = DataTypeToEnum<SPLIT_TYPE>::v();
250     std::vector<RaggedTensorVariant> decoded_components;
251     OP_REQUIRES_OK(context, RaggedComponentsFromVariant(
252                                 encoded_variant, input_ragged_rank_,
253                                 value_dtype, split_dtype, &decoded_components));
254 
255     // Corner case: input is a scalar.
256     if (encoded_variant.dims() == 0) {
257       ReturnRaggedTensor(context, decoded_components[0]);
258       return;
259     }
260 
261     // Nested-Stack Ragged components into a batched RaggedTensor.
262     std::vector<int> encoded_dim_sizes(encoded_variant.dims(), 0);
263     for (int i = 0; i < encoded_variant.dims(); i++) {
264       encoded_dim_sizes[i] = encoded_variant.dim_size(i);
265     }
266     RaggedTensorVariant output_ragged;
267     OP_REQUIRES_OK(
268         context, NestedStackRaggedTensors<VALUE_TYPE, SPLIT_TYPE>(
269                      decoded_components, encoded_dim_sizes, input_ragged_rank_,
270                      output_ragged_rank_, &output_ragged));
271 
272     // Set output.
273     ReturnRaggedTensor(context, output_ragged);
274   }
275 
276  private:
277   int input_ragged_rank_attr_;
278   int output_ragged_rank_;
279 
ReturnRaggedTensor(OpKernelContext * context,const RaggedTensorVariant & ragged_tensor)280   void ReturnRaggedTensor(OpKernelContext* context,
281                           const RaggedTensorVariant& ragged_tensor) {
282     int ragged_rank = ragged_tensor.ragged_rank();
283     OpOutputList splits_out;
284     OP_REQUIRES_OK(context,
285                    context->output_list("output_nested_splits", &splits_out));
286     for (int i = 0; i < ragged_rank; i++) {
287       splits_out.set(i, ragged_tensor.splits(i));
288     }
289     context->set_output(ragged_rank, ragged_tensor.values());
290   }
291 };
292 
293 #define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)      \
294   REGISTER_KERNEL_BUILDER(Name("RaggedTensorFromVariant")             \
295                               .Device(DEVICE_CPU)                     \
296                               .TypeConstraint<value_type>("Tvalues")  \
297                               .TypeConstraint<split_type>("Tsplits"), \
298                           RaggedTensorFromVariantOp<value_type, split_type>);
299 #define REGISTER_KERNELS(value_type)                  \
300   REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
301   REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64)
302 TF_CALL_POD_TYPES(REGISTER_KERNELS);
303 TF_CALL_tstring(REGISTER_KERNELS);
304 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
305 TF_CALL_quint16(REGISTER_KERNELS);
306 TF_CALL_qint16(REGISTER_KERNELS);
307 #undef REGISTER_KERNELS
308 #undef REGISTER_KERNELS_WITH_SPLIT_TYPE
309 }  // namespace tensorflow
310