• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/tensor_ops_util.h"
30 
31 namespace tensorflow {
32 namespace {
33 
34 template <typename VALUE_TYPE, typename SPLIT_TYPE>
UnbatchRaggedZerothDim(const RaggedTensorVariant & batched_ragged,std::vector<RaggedTensorVariant> * ragged_components)35 Status UnbatchRaggedZerothDim(
36     const RaggedTensorVariant& batched_ragged,
37     std::vector<RaggedTensorVariant>* ragged_components) {
38   // Set up the component Ragged Tensors.
39   int ragged_rank = batched_ragged.ragged_rank();
40   auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>();
41   int num_components = batched_splits_top_vec.size() - 1;
42   int num_splits = ragged_rank - 1;
43   ragged_components->resize(num_components);
44   for (RaggedTensorVariant& ragged_component : *ragged_components) {
45     ragged_component.mutable_nested_splits()->reserve(num_splits);
46   }
47   const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>();
48   int num_inner_elems = batched_ragged.values().NumElements();
49   if (batched_ragged.values().dim_size(0) > 1) {
50     num_inner_elems /= batched_ragged.values().dim_size(0);
51   }
52   TensorShape values_shape = batched_ragged.values().shape();
53 
54   // Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]]
55   if (num_splits == 0) {
56     for (int i = 0; i < num_components; i++) {
57       int start = batched_splits_top_vec(i);
58       int limit = batched_splits_top_vec(i + 1);
59       int num_values = limit - start;
60       values_shape.set_dim(0, num_values);
61       (*ragged_components)[i].set_values(
62           Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
63       auto ragged_component_values_flat =
64           (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
65       for (int j = 0; j < num_values * num_inner_elems; j++) {
66         ragged_component_values_flat(j) =
67             batched_flat(j + start * num_inner_elems);
68       }
69     }
70     return Status::OK();
71   }
72 
73   // Unbatch nested splits.
74   std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec;
75   batched_splits_vec.reserve(ragged_rank);
76   for (int i = 0; i < ragged_rank; i++) {
77     batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>());
78   }
79   std::vector<int> index(num_splits, 1);
80   std::vector<int> ragged_component_values_size(num_components, 0);
81   for (int i = 0; i < num_components; i++) {
82     std::vector<typename TTypes<SPLIT_TYPE>::Vec> ragged_component_splits_vec;
83     ragged_component_splits_vec.reserve(num_splits);
84     int split_size = -1;
85     for (int j = 0; j < num_splits; j++) {
86       if (j == 0) {
87         split_size =
88             batched_splits_top_vec(i + 1) - batched_splits_top_vec(i) + 1;
89       } else {
90         // Update split size based on previous split.
91         int last_index = ragged_component_splits_vec[j - 1].size() - 1;
92         split_size = ragged_component_splits_vec[j - 1](last_index) + 1;
93       }
94       (*ragged_components)[i].append_splits(
95           Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
96       ragged_component_splits_vec.push_back(
97           (*ragged_components)[i].mutable_splits(j)->vec<SPLIT_TYPE>());
98       SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1);
99       ragged_component_splits_vec[j](0) = 0;
100       for (int k = 1; k < split_size; k++, index[j]++) {
101         ragged_component_splits_vec[j](k) =
102             batched_splits_vec[j + 1](index[j]) - last_split_value;
103       }
104     }
105     int last_split_size = ragged_component_splits_vec[num_splits - 1].size();
106     ragged_component_values_size[i] =
107         ragged_component_splits_vec[num_splits - 1](last_split_size - 1);
108   }
109 
110   // Unbatch values.
111   int value_index = 0;
112   for (int i = 0; i < num_components; i++) {
113     int num_values = ragged_component_values_size[i];
114     values_shape.set_dim(0, num_values);
115     (*ragged_components)[i].set_values(
116         Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
117     auto ragged_component_values_flat =
118         (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
119     for (int j = 0; j < num_values * num_inner_elems; j++, value_index++) {
120       ragged_component_values_flat(j) = batched_flat(value_index);
121     }
122   }
123 
124   return Status::OK();
125 }
126 }  // namespace
127 
128 template <typename VALUE_TYPE, typename SPLIT_TYPE>
129 class RaggedTensorToVariantOp : public OpKernel {
130  public:
RaggedTensorToVariantOp(OpKernelConstruction * context)131   explicit RaggedTensorToVariantOp(OpKernelConstruction* context)
132       : OpKernel(context) {
133     OP_REQUIRES_OK(context, context->GetAttr("batched_input", &batched_input_));
134   }
135 
Compute(OpKernelContext * context)136   void Compute(OpKernelContext* context) override {
137     // Read ragged_splits inputs.
138     OpInputList ragged_nested_splits_in;
139     OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
140                                                 &ragged_nested_splits_in));
141     const int ragged_nested_splits_len = ragged_nested_splits_in.size();
142     RaggedTensorVariant batched_ragged_input;
143     // Read ragged_values input.
144     batched_ragged_input.set_values(context->input(ragged_nested_splits_len));
145     batched_ragged_input.mutable_nested_splits()->reserve(
146         ragged_nested_splits_len);
147     for (int i = 0; i < ragged_nested_splits_len; i++) {
148       batched_ragged_input.append_splits(ragged_nested_splits_in[i]);
149     }
150 
151     if (!batched_input_) {
152       // Encode as a Scalar Variant Tensor.
153       Tensor* encoded_scalar;
154       OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
155                                                        &encoded_scalar));
156       encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input);
157       return;
158     }
159 
160     // Checked here instead of at input in case batched_input_ is false
161     OP_REQUIRES(context, ragged_nested_splits_len > 0,
162                 errors::InvalidArgument(
163                     "rt_nested_splits must be a list of one or more, but "
164                     "received rt_nested_splits of length 0."));
165 
166     // Unbatch the Ragged Tensor and encode the components.
167     std::vector<RaggedTensorVariant> unbatched_ragged_input;
168     auto batched_splits_top_vec =
169         batched_ragged_input.splits(0).vec<SPLIT_TYPE>();
170     int num_components = batched_splits_top_vec.size() - 1;
171     OP_REQUIRES(context, num_components >= 0,
172                 errors::Internal("Invalid split argument."));
173     OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>(
174                                 batched_ragged_input, &unbatched_ragged_input));
175 
176     // Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor.
177     Tensor* encoded_vector;
178     int output_size = unbatched_ragged_input.size();
179     OP_REQUIRES_OK(context,
180                    context->allocate_output(0, TensorShape({output_size}),
181                                             &encoded_vector));
182     auto encoded_vector_t = encoded_vector->vec<Variant>();
183     for (int i = 0; i < output_size; i++) {
184       encoded_vector_t(i) = unbatched_ragged_input[i];
185     }
186   }
187 
188  private:
189   bool batched_input_;
190 };
191 
192 template <typename VALUE_TYPE, typename SPLIT_TYPE>
193 class RaggedTensorToVariantGradientOp : public OpKernel {
194  public:
195   using OpKernel::OpKernel;
196 
Compute(OpKernelContext * context)197   void Compute(OpKernelContext* context) override {
198     // Read inputs.
199     Tensor encoded_variant = context->input(0);
200     Tensor row_splits = context->input(1);
201     auto flat_row_splits = row_splits.flat<SPLIT_TYPE>();
202     TensorShape dense_values_shape;
203     OP_REQUIRES_OK(context,
204                    TensorShapeUtils::MakeShape(context->input(2).vec<int32>(),
205                                                &dense_values_shape));
206 
207     const auto& flat_variants = encoded_variant.flat<Variant>();
208 
209     // Get a Tensor containing the flat_values for each variant.
210     std::vector<Tensor> values;
211     for (int i = 0; i < flat_variants.size(); ++i) {
212       if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) {
213         values.push_back(encoded->values());
214       } else {
215         // Missing value: this happens if only some of the variant values
216         // generated by ragged_tensor_to_variant impacted the value that we're
217         // calculating the gradient for.  In this case, we will see a
218         // default-constructed variant; so treat it as a zero tensor with the
219         // appropriate shape.
220         const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
221         int piece_size = flat_row_splits(i + 1) - flat_row_splits(i);
222         TensorShape zeros_shape = dense_values_shape;
223         zeros_shape.set_dim(0, piece_size);
224         Tensor zero(value_dtype, zeros_shape);
225         zero.flat<VALUE_TYPE>() =
226             zero.flat<VALUE_TYPE>().constant(VALUE_TYPE());
227         values.push_back(zero);
228       }
229     }
230 
231     if (values.size() == 1) {
232       // Just one flat_value tensor: return as-is.
233       context->set_output(0, values[0]);
234     } else {
235       // Multiple flat_values tensors: concatenate them together.
236       using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix;
237       using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix;
238       std::vector<std::unique_ptr<ConstPiece>> pieces;
239       pieces.reserve(values.size());
240       for (const Tensor& t : values) {
241         pieces.emplace_back(
242             new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()})));
243       }
244       Tensor* out = nullptr;
245       OP_REQUIRES_OK(context,
246                      context->allocate_output(0, dense_values_shape, &out));
247       Piece out_flat =
248           out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()});
249       ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat);
250     }
251   }
252 };
253 
254 #define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)            \
255   REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant")                     \
256                               .Device(DEVICE_CPU)                           \
257                               .TypeConstraint<value_type>("Tvalues")        \
258                               .TypeConstraint<split_type>("Tsplits"),       \
259                           RaggedTensorToVariantOp<value_type, split_type>); \
260   REGISTER_KERNEL_BUILDER(                                                  \
261       Name("RaggedTensorToVariantGradient")                                 \
262           .Device(DEVICE_CPU)                                               \
263           .TypeConstraint<value_type>("Tvalues")                            \
264           .TypeConstraint<split_type>("Tsplits"),                           \
265       RaggedTensorToVariantGradientOp<value_type, split_type>);
266 
267 #define REGISTER_KERNELS(value_type)                  \
268   REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
269   REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64)
270 TF_CALL_POD_TYPES(REGISTER_KERNELS);
271 TF_CALL_tstring(REGISTER_KERNELS);
272 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
273 TF_CALL_quint16(REGISTER_KERNELS);
274 TF_CALL_qint16(REGISTER_KERNELS);
275 #undef REGISTER_KERNELS
276 #undef REGISTER_KERNELS_WITH_SPLIT_TYPE
277 }  // namespace tensorflow
278