/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/tensor_ops_util.h"

namespace tensorflow {
namespace {

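// Unbatches a batched RaggedTensorVariant along its zeroth dimension into
// one RaggedTensorVariant per row.
//
// Illustrative example (not from the original source): the batched ragged
// tensor [[1, 2, 3], [], [4, 5]] is encoded with
//   splits(0) = [0, 3, 3, 5],  values = [1, 2, 3, 4, 5]
// and unbatches into three components with values [1, 2, 3], [], and
// [4, 5]. For ragged_rank > 1, each component also receives re-based
// copies of the inner splits vectors.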
template <typename VALUE_TYPE, typename SPLIT_TYPE>
Status UnbatchRaggedZerothDim(
    const RaggedTensorVariant& batched_ragged,
    std::vector<RaggedTensorVariant>* ragged_components) {
  // Set up the component Ragged Tensors.
  int ragged_rank = batched_ragged.ragged_rank();
  auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>();
  int num_components = batched_splits_top_vec.size() - 1;
  // An empty splits vector would make num_components negative and the
  // resize below undefined; reject it instead.
  if (num_components < 0) {
    return errors::Internal("Invalid split argument.");
  }
  int num_splits = ragged_rank - 1;
  ragged_components->resize(num_components);
  for (RaggedTensorVariant& ragged_component : *ragged_components) {
    ragged_component.mutable_nested_splits()->reserve(num_splits);
  }
  const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>();
  int num_inner_elems = batched_ragged.values().NumElements();
  if (batched_ragged.values().dim_size(0) > 1) {
    num_inner_elems /= batched_ragged.values().dim_size(0);
  }
  TensorShape values_shape = batched_ragged.values().shape();

  // Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]]
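  // (Hypothetical walk-through: for that example, splits(0) = [0, 3, 5], so
  // component 0 copies the value range [0, 3) and component 1 copies [3, 5).)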
  if (num_splits == 0) {
    for (int i = 0; i < num_components; i++) {
      int start = batched_splits_top_vec(i);
      int limit = batched_splits_top_vec(i + 1);
      int num_values = limit - start;
      values_shape.set_dim(0, num_values);
      (*ragged_components)[i].set_values(
          Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
      auto ragged_component_values_flat =
          (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
      for (int j = 0; j < num_values * num_inner_elems; j++) {
        ragged_component_values_flat(j) =
            batched_flat(j + start * num_inner_elems);
      }
    }
    return Status::OK();
  }

  // Unbatch nested splits.
  std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec;
  batched_splits_vec.reserve(ragged_rank);
  for (int i = 0; i < ragged_rank; i++) {
    batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>());
  }
  std::vector<int> index(num_splits, 1);
  std::vector<int> ragged_component_values_size(num_components, 0);
  for (int i = 0; i < num_components; i++) {
    std::vector<typename TTypes<SPLIT_TYPE>::Vec> ragged_component_splits_vec;
    ragged_component_splits_vec.reserve(num_splits);
    int split_size = -1;
    for (int j = 0; j < num_splits; j++) {
      if (j == 0) {
        split_size =
            batched_splits_top_vec(i + 1) - batched_splits_top_vec(i) + 1;
      } else {
        // Update split size based on previous split.
        int last_index = ragged_component_splits_vec[j - 1].size() - 1;
        split_size = ragged_component_splits_vec[j - 1](last_index) + 1;
      }
      (*ragged_components)[i].append_splits(
          Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
      ragged_component_splits_vec.push_back(
          (*ragged_components)[i].mutable_splits(j)->vec<SPLIT_TYPE>());
      SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1);
      ragged_component_splits_vec[j](0) = 0;
      for (int k = 1; k < split_size; k++, index[j]++) {
        ragged_component_splits_vec[j](k) =
            batched_splits_vec[j + 1](index[j]) - last_split_value;
      }
    }
    int last_split_size = ragged_component_splits_vec[num_splits - 1].size();
    ragged_component_values_size[i] =
        ragged_component_splits_vec[num_splits - 1](last_split_size - 1);
  }
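
  // (Hypothetical walk-through: with nested splits [0, 2, 3] and
  // [0, 1, 3, 4], component 0 receives inner splits [0, 1, 3] and
  // component 1 receives the slice [3, 4] re-based to [0, 1], so their
  // values sizes are 3 and 1 respectively.)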

  // Unbatch values.
  int value_index = 0;
  for (int i = 0; i < num_components; i++) {
    int num_values = ragged_component_values_size[i];
    values_shape.set_dim(0, num_values);
    (*ragged_components)[i].set_values(
        Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
    auto ragged_component_values_flat =
        (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
    for (int j = 0; j < num_values * num_inner_elems; j++, value_index++) {
      ragged_component_values_flat(j) = batched_flat(value_index);
    }
  }

  return Status::OK();
}
}  // namespace

template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorToVariantOp : public OpKernel {
 public:
  explicit RaggedTensorToVariantOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("batched_input", &batched_input_));
  }

  void Compute(OpKernelContext* context) override {
    // Read ragged_splits inputs.
    OpInputList ragged_nested_splits_in;
    OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
                                                &ragged_nested_splits_in));
    const int ragged_nested_splits_len = ragged_nested_splits_in.size();
    RaggedTensorVariant batched_ragged_input;
    // Read ragged_values input.
    batched_ragged_input.set_values(context->input(ragged_nested_splits_len));
    batched_ragged_input.mutable_nested_splits()->reserve(
        ragged_nested_splits_len);
    for (int i = 0; i < ragged_nested_splits_len; i++) {
      batched_ragged_input.append_splits(ragged_nested_splits_in[i]);
    }

    if (!batched_input_) {
      // Encode as a Scalar Variant Tensor.
      Tensor* encoded_scalar;
      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
                                                       &encoded_scalar));
      encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input);
      return;
    }

    // Checked here rather than at input time because batched_input_ == false
    // legitimately allows an empty splits list; unbatching does not.
    OP_REQUIRES(context, ragged_nested_splits_len > 0,
                errors::InvalidArgument(
                    "rt_nested_splits must be a list of one or more, but "
                    "received rt_nested_splits of length 0."));

    // Unbatch the Ragged Tensor and encode the components.
    std::vector<RaggedTensorVariant> unbatched_ragged_input;
    OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>(
                                batched_ragged_input, &unbatched_ragged_input));

    // Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor.
    Tensor* encoded_vector;
    int output_size = unbatched_ragged_input.size();
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({output_size}),
                                            &encoded_vector));
    auto encoded_vector_t = encoded_vector->vec<Variant>();
    for (int i = 0; i < output_size; i++) {
      encoded_vector_t(i) = unbatched_ragged_input[i];
    }
  }

 private:
  bool batched_input_;
};
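
// Example semantics (illustrative, not from the original source): with
// batched_input=true, a ragged input [[1, 2], [3]] (splits [0, 2, 3]) is
// unbatched into a rank-1 DT_VARIANT tensor of length 2 whose elements
// encode [1, 2] and [3]; with batched_input=false, the whole ragged tensor
// is encoded as a single scalar variant.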

template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorToVariantGradientOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* context) override {
    // Read inputs.
    Tensor encoded_variant = context->input(0);
    Tensor row_splits = context->input(1);
    auto flat_row_splits = row_splits.flat<SPLIT_TYPE>();
    TensorShape dense_values_shape;
    OP_REQUIRES_OK(context,
                   TensorShapeUtils::MakeShape(context->input(2).vec<int32>(),
                                               &dense_values_shape));

    const auto& flat_variants = encoded_variant.flat<Variant>();

    // Get a Tensor containing the flat_values for each variant.
    std::vector<Tensor> values;
    for (int i = 0; i < flat_variants.size(); ++i) {
      if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) {
        values.push_back(encoded->values());
      } else {
        // Missing value: this happens if only some of the variant values
        // generated by ragged_tensor_to_variant impacted the value that we're
        // calculating the gradient for.  In this case, we will see a
        // default-constructed variant; so treat it as a zero tensor with the
        // appropriate shape.
        const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
        int piece_size = flat_row_splits(i + 1) - flat_row_splits(i);
        TensorShape zeros_shape = dense_values_shape;
        zeros_shape.set_dim(0, piece_size);
        Tensor zero(value_dtype, zeros_shape);
        zero.flat<VALUE_TYPE>() =
            zero.flat<VALUE_TYPE>().constant(VALUE_TYPE());
        values.push_back(zero);
      }
    }

    if (values.size() == 1) {
      // Just one flat_value tensor: return as-is.
      context->set_output(0, values[0]);
    } else {
      // Multiple flat_values tensors: concatenate them together.
      using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix;
      using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix;
      std::vector<std::unique_ptr<ConstPiece>> pieces;
      pieces.reserve(values.size());
      for (const Tensor& t : values) {
        pieces.emplace_back(
            new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()})));
      }
      Tensor* out = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, dense_values_shape, &out));
      Piece out_flat =
          out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()});
      ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat);
    }
  }
};
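
// Gradient sketch (illustrative, not from the original source): with
// row_splits = [0, 2, 3] and dense_values_shape = [3, D], a present variant
// for row 0 contributes its values tensor of shape [2, D], a missing variant
// for row 1 contributes zeros of shape [1, D], and the pieces are
// concatenated into the [3, D] output gradient.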

#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)            \
  REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant")                     \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<value_type>("Tvalues")        \
                              .TypeConstraint<split_type>("Tsplits"),       \
                          RaggedTensorToVariantOp<value_type, split_type>); \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("RaggedTensorToVariantGradient")                                 \
          .Device(DEVICE_CPU)                                               \
          .TypeConstraint<value_type>("Tvalues")                            \
          .TypeConstraint<split_type>("Tsplits"),                           \
      RaggedTensorToVariantGradientOp<value_type, split_type>);

#define REGISTER_KERNELS(value_type)                  \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE
}  // namespace tensorflow