1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/common_shape_fns.h"
16 #include "tensorflow/core/framework/op.h"
17 #include "tensorflow/core/framework/shape_inference.h"
18 #include "tensorflow/core/util/ragged_to_dense_util.h"
19
20 namespace tensorflow {
21
22 using errors::InvalidArgument;
23 using shape_inference::DimensionHandle;
24 using shape_inference::InferenceContext;
25 using shape_inference::ShapeHandle;
26
27 namespace {
ValidateRowPartitionTypesAndShapes(const std::vector<RowPartitionType> & row_partition_types,InferenceContext * c)28 tensorflow::Status ValidateRowPartitionTypesAndShapes(
29 const std::vector<RowPartitionType>& row_partition_types,
30 InferenceContext* c) {
31 // Note: the allowed types may be extended in the future.
32 for (RowPartitionType row_partition_type : row_partition_types) {
33 switch (row_partition_type) {
34 case RowPartitionType::FIRST_DIM_SIZE:
35 case RowPartitionType::VALUE_ROWIDS:
36 case RowPartitionType::ROW_SPLITS:
37 break;
38 default:
39 return InvalidArgument("Unsupported partition type: ",
40 RowPartitionTypeToString(row_partition_type));
41 }
42 }
43
44 if (row_partition_types.empty()) {
45 return InvalidArgument("Partition info types should not be empty");
46 }
47 for (int i = 1; i < row_partition_types.size(); ++i) {
48 if (row_partition_types[i] == RowPartitionType::FIRST_DIM_SIZE) {
49 return InvalidArgument("FIRST_DIM_SIZE must be first");
50 }
51 }
52 if (row_partition_types[0] == RowPartitionType::FIRST_DIM_SIZE &&
53 (row_partition_types.size() < 2 ||
54 row_partition_types[1] != RowPartitionType::VALUE_ROWIDS)) {
55 return InvalidArgument("FIRST_DIM_SIZE must be followed by VALUE_ROWIDS");
56 }
57 if (row_partition_types[0] == RowPartitionType::VALUE_ROWIDS) {
58 return InvalidArgument("VALUE_ROWIDS cannot be first");
59 }
60
61 int num_row_partition_tensors;
62 TF_RETURN_IF_ERROR(
63 c->GetAttr("num_row_partition_tensors", &num_row_partition_tensors));
64 if (num_row_partition_tensors != row_partition_types.size()) {
65 return InvalidArgument(
66 "Number of row partition tensors (", num_row_partition_tensors,
67 ") does not equal the number of row partition types(",
68 row_partition_types.size(), ").");
69 }
70
71 for (int i = 0; i < num_row_partition_tensors; ++i) {
72 TensorShapeProto partition_shape;
73 c->ShapeHandleToProto(c->input(3 + i), &partition_shape);
74 if (partition_shape.unknown_rank()) {
75 continue;
76 }
77 if (row_partition_types[i] == RowPartitionType::FIRST_DIM_SIZE) {
78 if (partition_shape.dim_size() != 0) {
79 return InvalidArgument("FIRST_DIM_SIZE must be a scalar.");
80 }
81 } else {
82 if (partition_shape.dim_size() != 1) {
83 return InvalidArgument("Row partition must be a vector.");
84 }
85 }
86 }
87 return tensorflow::Status::OK();
88 }
89
90 } // namespace
91
// Forward declarations of the shape functions used by the REGISTER_OP calls
// below; the definitions follow the registrations.
Status RaggedTensorToSparseShapeFn(InferenceContext* c);
Status RaggedTensorToVariantShapeFn(InferenceContext* c);
Status RaggedTensorFromVariantShapeFn(InferenceContext* c);
Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c);
Status RaggedTensorToTensorShapeFn(InferenceContext* c);
97
98 //==============================================================================
99 // Registered Ops
100 //==============================================================================
101
// Converts a RaggedTensor — given as RAGGED_RANK nested row-splits vectors
// plus a dense values tensor — into SparseTensor components
// (indices, values, dense_shape).
REGISTER_OP("RaggedTensorToSparse")
    .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
    .Input("rt_dense_values: T")
    .Output("sparse_indices: int64")
    .Output("sparse_values: T")
    .Output("sparse_dense_shape: int64")
    .Attr("RAGGED_RANK: int >= 1")
    .Attr("T: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .SetShapeFn(RaggedTensorToSparseShapeFn);
112
// Encodes a RaggedTensor into a variant tensor.  When `batched_input` is
// true each row is encoded separately (the output is a vector of variants);
// otherwise the whole RaggedTensor becomes a single scalar variant.
REGISTER_OP("RaggedTensorToVariant")
    .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
    .Input("rt_dense_values: Tvalues")
    .Output("encoded_ragged: variant")
    .Attr("RAGGED_RANK: int >= 0")
    .Attr("Tvalues: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .Attr("batched_input: bool")
    .SetShapeFn(RaggedTensorToVariantShapeFn);
122
// Decodes a variant tensor back into a RaggedTensor with
// `output_ragged_rank` nested splits vectors plus a dense values tensor.
// `input_ragged_rank` may be -1 to indicate it is unknown statically.
REGISTER_OP("RaggedTensorFromVariant")
    .Input("encoded_ragged: variant")
    .Output("output_nested_splits: output_ragged_rank * Tsplits")
    .Output("output_dense_values: Tvalues")
    .Attr("input_ragged_rank: int >= -1")
    .Attr("output_ragged_rank: int >= 0")
    .Attr("Tvalues: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .SetShapeFn(RaggedTensorFromVariantShapeFn);
132
// Gradient op for RaggedTensorToVariant: maps the incoming variant-encoded
// gradient back to a gradient for the dense values, with the shape given by
// the `dense_values_shape` input.
REGISTER_OP("RaggedTensorToVariantGradient")
    .Input("encoded_ragged_grad: variant")
    .Input("row_splits: Tsplits")
    .Input("dense_values_shape: int32")
    .Output("dense_values_grad: Tvalues")
    .Attr("Tvalues: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .SetShapeFn(RaggedTensorToVariantGradientShapeFn);
141
// Converts a RaggedTensor (flat `values` plus `row_partition_tensors`
// described by the `row_partition_types` attr) into a dense tensor of the
// requested `shape`, filling gaps with `default_value`.
REGISTER_OP("RaggedTensorToTensor")
    .Attr("T: type")
    .Attr("Tindex: {int64, int32}")
    .Attr("Tshape: {int64, int32}")
    .Attr("num_row_partition_tensors: int")
    .Attr("row_partition_types: list(string)")
    .Input("shape: Tshape")
    .Input("values: T")
    .Input("default_value: T")
    .Input("row_partition_tensors: num_row_partition_tensors * Tindex")
    .Output("result: T")
    .SetShapeFn(RaggedTensorToTensorShapeFn);
154
155 //==============================================================================
156 // Shape Functions
157 //==============================================================================
158
RaggedTensorToSparseShapeFn(InferenceContext * c)159 Status RaggedTensorToSparseShapeFn(InferenceContext* c) {
160 int64 num_splits;
161 TF_RETURN_IF_ERROR(c->GetAttr<int64>("RAGGED_RANK", &num_splits));
162 // TODO(b/112274756): Allow ragged_rank to be 0.
163 if (num_splits < 1) {
164 return errors::InvalidArgument("Requires RAGGED_RANK>0");
165 }
166 ShapeHandle rt_dense_values = c->input(num_splits);
167 TF_RETURN_IF_ERROR(c->WithRankAtLeast(rt_dense_values, 1, &rt_dense_values));
168
169 // Check that all rt_nested_splits have rank 1.
170 for (int64 i = 0; i < num_splits; ++i) {
171 ShapeHandle splits = c->input(i);
172 TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits));
173 }
174
175 DimensionHandle dense_dims =
176 c->RankKnown(rt_dense_values)
177 ? c->MakeDim(c->Rank(rt_dense_values) + num_splits)
178 : c->UnknownDim();
179 DimensionHandle num_values = c->NumElements(rt_dense_values);
180
181 c->set_output(0, c->Matrix(num_values, dense_dims)); // indices
182 c->set_output(1, c->Vector(num_values)); // values
183 c->set_output(2, c->Vector(dense_dims)); // dense_shape
184
185 return Status::OK();
186 }
187
RaggedTensorToVariantShapeFn(InferenceContext * c)188 Status RaggedTensorToVariantShapeFn(InferenceContext* c) {
189 int64 num_splits;
190 TF_RETURN_IF_ERROR(c->GetAttr<int64>("RAGGED_RANK", &num_splits));
191 bool batched;
192 TF_RETURN_IF_ERROR(c->GetAttr<bool>("batched_input", &batched));
193 shape_inference::ShapeHandle rt_dense_values = c->input(num_splits);
194 TF_RETURN_IF_ERROR(c->WithRankAtLeast(rt_dense_values, 1, &rt_dense_values));
195 for (int64 i = 0; i < num_splits; ++i) {
196 shape_inference::ShapeHandle splits = c->input(i);
197 TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits));
198 }
199 if (batched) {
200 auto num_first_splits = c->Dim(c->input(0), 0);
201 shape_inference::DimensionHandle num_rows;
202 TF_RETURN_IF_ERROR(c->Subtract(num_first_splits, 1, &num_rows));
203 c->set_output(0, c->Vector(num_rows));
204 } else {
205 c->set_output(0, c->Scalar());
206 }
207 if (batched && num_splits == 0) {
208 return errors::InvalidArgument(
209 "ragged_rank=0 is not currently supported when batched_input=true.");
210 }
211 return Status::OK();
212 }
213
RaggedTensorToVariantGradientShapeFn(InferenceContext * c)214 Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) {
215 ShapeHandle shape;
216 TF_RETURN_IF_ERROR(
217 c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape));
218 c->set_output(0, shape);
219 return Status::OK();
220 }
221
RaggedTensorFromVariantShapeFn(InferenceContext * c)222 Status RaggedTensorFromVariantShapeFn(InferenceContext* c) {
223 int64 input_ragged_rank;
224 TF_RETURN_IF_ERROR(
225 c->GetAttr<int64>("input_ragged_rank", &input_ragged_rank));
226 int64 output_ragged_rank;
227 TF_RETURN_IF_ERROR(
228 c->GetAttr<int64>("output_ragged_rank", &output_ragged_rank));
229 shape_inference::ShapeHandle encoded_ragged = c->input(0);
230 if (c->RankKnown(encoded_ragged) && input_ragged_rank >= 0) {
231 shape_inference::ShapeHandle unused;
232 TF_RETURN_IF_ERROR(c->WithRank(
233 encoded_ragged, output_ragged_rank - input_ragged_rank, &unused));
234 }
235 for (int64 i = 0; i < output_ragged_rank; i++) {
236 c->set_output(i, c->UnknownShapeOfRank(1));
237 }
238 c->set_output(output_ragged_rank, c->UnknownShape());
239 return Status::OK();
240 }
241
RaggedTensorToTensorShapeFn(InferenceContext * c)242 tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c) {
243 TensorShapeProto shape;
244 {
245 ShapeHandle shape_handle;
246 TF_RETURN_IF_ERROR(
247 c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &shape_handle));
248 c->ShapeHandleToProto(shape_handle, &shape);
249 }
250
251 std::vector<RowPartitionType> row_partition_types;
252 TF_RETURN_IF_ERROR(GetRowPartitionTypes(c, &row_partition_types));
253 int ragged_rank = GetRaggedRank(row_partition_types);
254 TF_RETURN_IF_ERROR(
255 ValidateRowPartitionTypesAndShapes(row_partition_types, c));
256
257 TensorShapeProto value_shape;
258 c->ShapeHandleToProto(c->input(1), &value_shape);
259
260 TensorShapeProto default_value_shape;
261 c->ShapeHandleToProto(c->input(2), &default_value_shape);
262
263 TF_RETURN_IF_ERROR(
264 ValidateDefaultValueShape(default_value_shape, value_shape));
265
266 // TODO(martinz): Theoretically, we could check the first dimension of
267 // value_shape against the first dimension of the last row_partition_tensor
268 // assuming it is a VALUE_ROWIDS type.
269 // TODO(martinz): Although we normally don't know the first dimension of the
270 // output, we could infer it from the first dimension of the first
271 // row_partition_tensor if it is ROW_SPLITS type.
272 // TODO(martinz): If the shape is provided, but the value_shape has missing
273 // dimensions, we can check the default_value_shape against the shape.
274 TensorShapeProto output_shape;
275 TF_RETURN_IF_ERROR(CombineRaggedTensorToTensorShapes(
276 ragged_rank, shape, value_shape, &output_shape));
277
278 ShapeHandle output_shape_handle;
279 TF_RETURN_IF_ERROR(
280 c->MakeShapeFromShapeProto(output_shape, &output_shape_handle));
281 c->set_output(0, output_shape_handle);
282 return Status::OK();
283 }
284
285 } // namespace tensorflow
286