/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/tensor_ops_util.h"

namespace tensorflow {
namespace {

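// Unbatches `batched_ragged` along its outermost (zeroth) dimension, writing
// one RaggedTensorVariant per outermost row into `ragged_components`.
// Component i spans rows [splits(0)(i), splits(0)(i + 1)) of the next nesting
// level, or of the values tensor when ragged_rank == 1; each component's
// splits are rebased to start at zero.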
template <typename VALUE_TYPE, typename SPLIT_TYPE>
Status UnbatchRaggedZerothDim(
    const RaggedTensorVariant& batched_ragged,
    std::vector<RaggedTensorVariant>* ragged_components) {
  // Set up the component Ragged Tensors.
  int ragged_rank = batched_ragged.ragged_rank();
  auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>();
  int num_components = batched_splits_top_vec.size() - 1;
  int num_splits = ragged_rank - 1;
  ragged_components->resize(num_components);
  for (RaggedTensorVariant& ragged_component : *ragged_components) {
    ragged_component.mutable_nested_splits()->reserve(num_splits);
  }
  const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>();
  int num_inner_elems = batched_ragged.values().NumElements();
  if (batched_ragged.values().dim_size(0) > 1) {
    num_inner_elems /= batched_ragged.values().dim_size(0);
  }
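  // num_inner_elems is now the number of scalars per row of the outermost
  // values dimension (the product of the inner dimensions, for a non-empty
  // values tensor).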
  TensorShape values_shape = batched_ragged.values().shape();

  // Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]]
  if (num_splits == 0) {
    for (int i = 0; i < num_components; i++) {
      int start = batched_splits_top_vec(i);
      int limit = batched_splits_top_vec(i + 1);
      int num_values = limit - start;
      values_shape.set_dim(0, num_values);
      (*ragged_components)[i].set_values(
          Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
      auto ragged_component_values_flat =
          (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
      for (int j = 0; j < num_values * num_inner_elems; j++) {
        ragged_component_values_flat(j) =
            batched_flat(j + start * num_inner_elems);
      }
    }
    return Status::OK();
  }

  // Unbatch nested splits.
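  // For example, with ragged_rank == 2:
  //   batched: splits(0) = [0, 2, 3], splits(1) = [0, 1, 3, 5]
  //   component 0: splits = [0, 1, 3]  (entries 1..2 of splits(1))
  //   component 1: splits = [0, 2]     (entry 3, rebased by subtracting 3)
  // Below, index[j] tracks the read position within batched splits(j + 1).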
  std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec;
  batched_splits_vec.reserve(ragged_rank);
  for (int i = 0; i < ragged_rank; i++) {
    batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>());
  }
  std::vector<int> index(num_splits, 1);
  std::vector<int> ragged_component_values_size(num_components, 0);
  for (int i = 0; i < num_components; i++) {
    std::vector<typename TTypes<SPLIT_TYPE>::Vec> ragged_component_splits_vec;
    ragged_component_splits_vec.reserve(num_splits);
    int split_size = -1;
    for (int j = 0; j < num_splits; j++) {
      if (j == 0) {
        split_size =
            batched_splits_top_vec(i + 1) - batched_splits_top_vec(i) + 1;
      } else {
        // Update split size based on previous split.
        int last_index = ragged_component_splits_vec[j - 1].size() - 1;
        split_size = ragged_component_splits_vec[j - 1](last_index) + 1;
      }
      (*ragged_components)[i].append_splits(
          Tensor(DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
      ragged_component_splits_vec.push_back(
          (*ragged_components)[i].mutable_splits(j)->vec<SPLIT_TYPE>());
      SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1);
      ragged_component_splits_vec[j](0) = 0;
      for (int k = 1; k < split_size; k++, index[j]++) {
        ragged_component_splits_vec[j](k) =
            batched_splits_vec[j + 1](index[j]) - last_split_value;
      }
    }
    int last_split_size = ragged_component_splits_vec[num_splits - 1].size();
    ragged_component_values_size[i] =
        ragged_component_splits_vec[num_splits - 1](last_split_size - 1);
  }

  // Unbatch values.
  int value_index = 0;
  for (int i = 0; i < num_components; i++) {
    int num_values = ragged_component_values_size[i];
    values_shape.set_dim(0, num_values);
    (*ragged_components)[i].set_values(
        Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
    auto ragged_component_values_flat =
        (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
    for (int j = 0; j < num_values * num_inner_elems; j++, value_index++) {
      ragged_component_values_flat(j) = batched_flat(value_index);
    }
  }

  return Status::OK();
}
}  // namespace

template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorToVariantOp : public OpKernel {
 public:
  explicit RaggedTensorToVariantOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("batched_input", &batched_input_));
  }

  void Compute(OpKernelContext* context) override {
    // Read ragged_splits inputs.
    OpInputList ragged_nested_splits_in;
    OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
                                                &ragged_nested_splits_in));
    const int ragged_nested_splits_len = ragged_nested_splits_in.size();
    RaggedTensorVariant batched_ragged_input;
    // Read ragged_values input.
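    // The values tensor follows the rt_nested_splits list, so its input
    // index equals the number of splits inputs.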
    batched_ragged_input.set_values(context->input(ragged_nested_splits_len));
    batched_ragged_input.mutable_nested_splits()->reserve(
        ragged_nested_splits_len);
    for (int i = 0; i < ragged_nested_splits_len; i++) {
      batched_ragged_input.append_splits(ragged_nested_splits_in[i]);
    }

    if (!batched_input_) {
      // Encode as a Scalar Variant Tensor.
      Tensor* encoded_scalar;
      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
                                                       &encoded_scalar));
      encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input);
      return;
    }

    // Checked here rather than at input time, because this requirement only
    // applies when batched_input_ is true.
    OP_REQUIRES(context, ragged_nested_splits_len > 0,
                errors::InvalidArgument(
                    "rt_nested_splits must be a list of one or more tensors, "
                    "but received rt_nested_splits of length 0."));

    // Unbatch the Ragged Tensor and encode the components.
    std::vector<RaggedTensorVariant> unbatched_ragged_input;
    auto batched_splits_top_vec =
        batched_ragged_input.splits(0).vec<SPLIT_TYPE>();
    int num_components = batched_splits_top_vec.size() - 1;
    OP_REQUIRES(context, num_components >= 0,
                errors::Internal("Invalid split argument."));
    OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>(
                                batched_ragged_input, &unbatched_ragged_input));

    // Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor.
    Tensor* encoded_vector;
    int output_size = unbatched_ragged_input.size();
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({output_size}),
                                            &encoded_vector));
    auto encoded_vector_t = encoded_vector->vec<Variant>();
    for (int i = 0; i < output_size; i++) {
      encoded_vector_t(i) = unbatched_ragged_input[i];
    }
  }

 private:
  bool batched_input_;
};

template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorToVariantGradientOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* context) override {
    // Read inputs.
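    // Inputs: (0) the variant-encoded ragged components produced by the
    // forward op, (1) the outer row splits of the original ragged tensor,
    // and (2) the shape of the dense_values gradient to build.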
    Tensor encoded_variant = context->input(0);
    Tensor row_splits = context->input(1);
    auto flat_row_splits = row_splits.flat<SPLIT_TYPE>();
    TensorShape dense_values_shape;
    OP_REQUIRES_OK(context,
                   TensorShapeUtils::MakeShape(context->input(2).vec<int32>(),
                                               &dense_values_shape));

    const auto& flat_variants = encoded_variant.flat<Variant>();

    // Get a Tensor containing the flat_values for each variant.
    std::vector<Tensor> values;
    for (int i = 0; i < flat_variants.size(); ++i) {
      if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) {
        values.push_back(encoded->values());
      } else {
        // Missing value: this happens if only some of the variant values
        // generated by ragged_tensor_to_variant impacted the value that we're
        // calculating the gradient for. In this case, we will see a
        // default-constructed variant; so treat it as a zero tensor with the
        // appropriate shape.
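        // For example, with flat_row_splits = [0, 2, 5], a missing variant
        // at i == 1 yields a zero tensor with 5 - 2 = 3 outer rows.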
        const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
        int piece_size = flat_row_splits(i + 1) - flat_row_splits(i);
        TensorShape zeros_shape = dense_values_shape;
        zeros_shape.set_dim(0, piece_size);
        Tensor zero(value_dtype, zeros_shape);
        zero.flat<VALUE_TYPE>() =
            zero.flat<VALUE_TYPE>().constant(VALUE_TYPE());
        values.push_back(zero);
      }
    }

    if (values.size() == 1) {
      // Just one flat_values tensor: return it as-is.
      context->set_output(0, values[0]);
    } else {
      // Multiple flat_values tensors: concatenate them together.
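      // Each piece is viewed as a 1 x NumElements matrix so that ConcatCPU
      // lays the pieces out end-to-end in the flat output buffer.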
      using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix;
      using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix;
      std::vector<std::unique_ptr<ConstPiece>> pieces;
      pieces.reserve(values.size());
      for (const Tensor& t : values) {
        pieces.emplace_back(
            new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()})));
      }
      Tensor* out = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, dense_values_shape, &out));
      Piece out_flat =
          out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()});
      ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat);
    }
  }
};

#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)      \
  REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant")               \
                              .Device(DEVICE_CPU)                     \
                              .TypeConstraint<value_type>("Tvalues")  \
                              .TypeConstraint<split_type>("Tsplits"), \
                          RaggedTensorToVariantOp<value_type, split_type>); \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("RaggedTensorToVariantGradient")                           \
          .Device(DEVICE_CPU)                                         \
          .TypeConstraint<value_type>("Tvalues")                      \
          .TypeConstraint<split_type>("Tsplits"),                     \
      RaggedTensorToVariantGradientOp<value_type, split_type>);

#define REGISTER_KERNELS(value_type)                  \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE
}  // namespace tensorflow