/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/ragged_tensor_variant.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/tensor_ops_util.h"

namespace tensorflow {
namespace {
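// Splits `batched_ragged` into its components along the zeroth (batch)
// dimension and writes one RaggedTensorVariant per component into
// `ragged_components`. Illustrative example (not from the original source):
// unbatching the ragged_rank-2 tensor [[[1, 2]], [], [[3], [4, 5]]] yields
// the three components [[1, 2]], [], and [[3], [4, 5]], each with row
// splits re-based to start at zero.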
template <typename VALUE_TYPE, typename SPLIT_TYPE>
Status UnbatchRaggedZerothDim(
    const RaggedTensorVariant& batched_ragged,
    std::vector<RaggedTensorVariant>* ragged_components) {
  // Set up the component Ragged Tensors.
  int ragged_rank = batched_ragged.ragged_rank();
  auto batched_splits_top_vec = batched_ragged.splits(0).vec<SPLIT_TYPE>();
  int num_components = batched_splits_top_vec.size() - 1;
  int num_splits = ragged_rank - 1;
  ragged_components->resize(num_components);
  for (RaggedTensorVariant& ragged_component : *ragged_components) {
    ragged_component.mutable_nested_splits()->reserve(num_splits);
  }
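  // View the batched values as a flat buffer. `num_inner_elems` is the
  // number of elements in a single row of the values tensor, i.e. the
  // product of any dense inner dimensions.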
  const auto& batched_flat = batched_ragged.values().flat<VALUE_TYPE>();
  int num_inner_elems = batched_ragged.values().NumElements();
  if (batched_ragged.values().dim_size(0) > 1) {
    num_inner_elems /= batched_ragged.values().dim_size(0);
  }
  TensorShape values_shape = batched_ragged.values().shape();

  // Corner case: ragged_rank == 1, e.g. [[1, 2, 3], [4, 5]]
  if (num_splits == 0) {
    for (int i = 0; i < num_components; i++) {
      int start = batched_splits_top_vec(i);
      int limit = batched_splits_top_vec(i + 1);
      int num_values = limit - start;
      values_shape.set_dim(0, num_values);
      (*ragged_components)[i].set_values(
          Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
      auto ragged_component_values_flat =
          (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
      for (int j = 0; j < num_values * num_inner_elems; j++) {
        ragged_component_values_flat(j) =
            batched_flat(j + start * num_inner_elems);
      }
    }
    return Status::OK();
  }

  // Unbatch nested splits.
  std::vector<typename TTypes<SPLIT_TYPE>::ConstVec> batched_splits_vec;
  batched_splits_vec.reserve(ragged_rank);
  for (int i = 0; i < ragged_rank; i++) {
    batched_splits_vec.push_back(batched_ragged.splits(i).vec<SPLIT_TYPE>());
  }
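  // index[j] tracks the read position within the batched splits at nesting
  // level j + 1; ragged_component_values_size[i] records how many value
  // rows component i will receive.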
  std::vector<int> index(num_splits, 1);
  std::vector<int> ragged_component_values_size(num_components, 0);
  for (int i = 0; i < num_components; i++) {
    std::vector<typename TTypes<SPLIT_TYPE>::Vec> ragged_component_splits_vec;
    ragged_component_splits_vec.reserve(num_splits);
    int split_size = -1;
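    // At level j == 0 the number of split entries comes from the top-level
    // splits; at deeper levels it is one more than the last entry of the
    // splits built at the previous level (i.e. rows + 1).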
    for (int j = 0; j < num_splits; j++) {
      if (j == 0) {
        split_size =
            batched_splits_top_vec(i + 1) - batched_splits_top_vec(i) + 1;
      } else {
        // Update split size based on previous split.
        int last_index = ragged_component_splits_vec[j - 1].size() - 1;
        split_size = ragged_component_splits_vec[j - 1](last_index) + 1;
      }
      (*ragged_components)[i].append_splits(Tensor(
          DataTypeToEnum<SPLIT_TYPE>::value, TensorShape({split_size})));
      ragged_component_splits_vec.push_back(
          (*ragged_components)[i].mutable_splits(j)->vec<SPLIT_TYPE>());
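      // Re-base this component's splits so they start at zero by
      // subtracting the last batched split value consumed before this
      // component's range.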
      SPLIT_TYPE last_split_value = batched_splits_vec[j + 1](index[j] - 1);
      ragged_component_splits_vec[j](0) = 0;
      for (int k = 1; k < split_size; k++, index[j]++) {
        ragged_component_splits_vec[j](k) =
            batched_splits_vec[j + 1](index[j]) - last_split_value;
      }
    }
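    // The final entry of the innermost splits is the number of value rows
    // this component owns.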
    int last_split_size = ragged_component_splits_vec[num_splits - 1].size();
    ragged_component_values_size[i] =
        ragged_component_splits_vec[num_splits - 1](last_split_size - 1);
  }

  // Unbatch values.
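  // The batched values are the components' rows concatenated in order, so
  // they can be copied back out with a single running offset.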
  int value_index = 0;
  for (int i = 0; i < num_components; i++) {
    int num_values = ragged_component_values_size[i];
    values_shape.set_dim(0, num_values);
    (*ragged_components)[i].set_values(
        Tensor(DataTypeToEnum<VALUE_TYPE>::value, values_shape));
    auto ragged_component_values_flat =
        (*ragged_components)[i].mutable_values()->flat<VALUE_TYPE>();
    for (int j = 0; j < num_values * num_inner_elems; j++, value_index++) {
      ragged_component_values_flat(j) = batched_flat(value_index);
    }
  }

  return Status::OK();
}
}  // namespace

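// Kernel for RaggedTensorToVariant. Encodes a RaggedTensor, passed in as a
// values tensor plus a list of nested row-splits tensors, into variants.
// With batched_input=false the whole input becomes a single scalar variant;
// with batched_input=true the input is unbatched along dimension 0 and each
// component is encoded as one element of a rank-1 variant tensor.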
template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorToVariantOp : public OpKernel {
 public:
  explicit RaggedTensorToVariantOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("batched_input", &batched_input_));
  }

  void Compute(OpKernelContext* context) override {
    // Read ragged_splits inputs.
    OpInputList ragged_nested_splits_in;
    OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
                                                &ragged_nested_splits_in));
    const int ragged_nested_splits_len = ragged_nested_splits_in.size();
    RaggedTensorVariant batched_ragged_input;
    // Read ragged_values input.
    batched_ragged_input.set_values(context->input(ragged_nested_splits_len));
    batched_ragged_input.mutable_nested_splits()->reserve(
        ragged_nested_splits_len);
    for (int i = 0; i < ragged_nested_splits_len; i++) {
      batched_ragged_input.append_splits(ragged_nested_splits_in[i]);
    }

    if (!batched_input_) {
      // Encode as a Scalar Variant Tensor.
      Tensor* encoded_scalar;
      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}),
                                                       &encoded_scalar));
      encoded_scalar->scalar<Variant>()() = std::move(batched_ragged_input);
      return;
    }

    // Unbatch the Ragged Tensor and encode the components. Guard against an
    // empty splits list first: UnbatchRaggedZerothDim reads splits(0), which
    // would otherwise dereference past the end of the nested splits.
    OP_REQUIRES(context, ragged_nested_splits_len > 0,
                errors::InvalidArgument(
                    "rt_nested_splits must contain at least one splits "
                    "tensor when batched_input=true"));
    std::vector<RaggedTensorVariant> unbatched_ragged_input;
    OP_REQUIRES_OK(context, UnbatchRaggedZerothDim<VALUE_TYPE, SPLIT_TYPE>(
                                batched_ragged_input, &unbatched_ragged_input));

    // Bundle the encoded scalar Variant Tensors into a rank-1 Variant Tensor.
    Tensor* encoded_vector;
    int output_size = unbatched_ragged_input.size();
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({output_size}),
                                            &encoded_vector));
    auto encoded_vector_t = encoded_vector->vec<Variant>();
    for (int i = 0; i < output_size; i++) {
      encoded_vector_t(i) = std::move(unbatched_ragged_input[i]);
    }
  }

 private:
  bool batched_input_;
};

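// Kernel for RaggedTensorToVariantGradient. Computes the gradient with
// respect to the dense_values input of RaggedTensorToVariant from the
// variant-encoded gradients of that op's output components.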
template <typename VALUE_TYPE, typename SPLIT_TYPE>
class RaggedTensorToVariantGradientOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* context) override {
    // Read inputs.
    Tensor encoded_variant = context->input(0);
    Tensor row_splits = context->input(1);
    auto flat_row_splits = row_splits.flat<SPLIT_TYPE>();
    TensorShape dense_values_shape;
    OP_REQUIRES_OK(context,
                   TensorShapeUtils::MakeShape(context->input(2).vec<int32>(),
                                               &dense_values_shape));
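    // The output gradient is allocated with `dense_values_shape` (built from
    // input 2); `row_splits` gives the number of value rows each variant
    // covers, which sizes the zero tensors for missing variants below.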

    const auto& flat_variants = encoded_variant.flat<Variant>();

    // Get a Tensor containing the flat_values for each variant.
    std::vector<Tensor> values;
    for (int i = 0; i < flat_variants.size(); ++i) {
      if (const auto* encoded = flat_variants(i).get<RaggedTensorVariant>()) {
        values.push_back(encoded->values());
      } else {
        // Missing value: this happens if only some of the variant values
        // generated by ragged_tensor_to_variant impacted the value that we're
        // calculating the gradient for. In this case, we will see a
        // default-constructed variant; so treat it as a zero tensor with the
        // appropriate shape.
        const auto value_dtype = DataTypeToEnum<VALUE_TYPE>::v();
        int piece_size = flat_row_splits(i + 1) - flat_row_splits(i);
        TensorShape zeros_shape = dense_values_shape;
        zeros_shape.set_dim(0, piece_size);
        Tensor zero(value_dtype, zeros_shape);
        zero.flat<VALUE_TYPE>() =
            zero.flat<VALUE_TYPE>().constant(VALUE_TYPE());
        values.push_back(zero);
      }
    }

    if (values.size() == 1) {
      // Just one flat_value tensor: return as-is.
      context->set_output(0, values[0]);
    } else {
      // Multiple flat_values tensors: concatenate them together.
      using Piece = typename TTypes<VALUE_TYPE, 2>::Matrix;
      using ConstPiece = typename TTypes<VALUE_TYPE, 2>::ConstMatrix;
      std::vector<std::unique_ptr<ConstPiece>> pieces;
      pieces.reserve(values.size());
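      // View each piece as a 1 x NumElements matrix so that ConcatCPU
      // concatenates the pieces along the second axis, i.e. back-to-back in
      // a single flat output buffer.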
      for (const Tensor& t : values) {
        pieces.emplace_back(
            new ConstPiece(t.shaped<VALUE_TYPE, 2>({1, t.NumElements()})));
      }
      Tensor* out = nullptr;
      OP_REQUIRES_OK(context,
                     context->allocate_output(0, dense_values_shape, &out));
      Piece out_flat =
          out->shaped<VALUE_TYPE, 2>({1, dense_values_shape.num_elements()});
      ConcatCPU<VALUE_TYPE>(context->device(), pieces, &out_flat);
    }
  }
};

#define REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, split_type)            \
  REGISTER_KERNEL_BUILDER(Name("RaggedTensorToVariant")                     \
                              .Device(DEVICE_CPU)                           \
                              .TypeConstraint<value_type>("Tvalues")        \
                              .TypeConstraint<split_type>("Tsplits"),       \
                          RaggedTensorToVariantOp<value_type, split_type>); \
  REGISTER_KERNEL_BUILDER(                                                  \
      Name("RaggedTensorToVariantGradient")                                 \
          .Device(DEVICE_CPU)                                               \
          .TypeConstraint<value_type>("Tvalues")                            \
          .TypeConstraint<split_type>("Tsplits"),                           \
      RaggedTensorToVariantGradientOp<value_type, split_type>);

#define REGISTER_KERNELS(value_type)                  \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int32) \
  REGISTER_KERNELS_WITH_SPLIT_TYPE(value_type, int64)
TF_CALL_POD_TYPES(REGISTER_KERNELS);
TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE
}  // namespace tensorflow