/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/ragged_to_dense_util.h"

namespace tensorflow {

using errors::InvalidArgument;
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

namespace {
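// Validates the `row_partition_types` attribute of RaggedTensorToTensor and
// the ranks of the corresponding `row_partition_tensors` inputs. Only
// FIRST_DIM_SIZE, VALUE_ROWIDS, and ROW_SPLITS are accepted; FIRST_DIM_SIZE
// may only appear first and must be followed by VALUE_ROWIDS, and each
// partition tensor must be a scalar (FIRST_DIM_SIZE) or a vector (otherwise).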
tensorflow::Status ValidateRowPartitionTypesAndShapes(
    const std::vector<RowPartitionType>& row_partition_types,
    InferenceContext* c) {
  // Note: the allowed types may be extended in the future.
  for (RowPartitionType row_partition_type : row_partition_types) {
    switch (row_partition_type) {
      case RowPartitionType::FIRST_DIM_SIZE:
      case RowPartitionType::VALUE_ROWIDS:
      case RowPartitionType::ROW_SPLITS:
        break;
      default:
        return InvalidArgument("Unsupported partition type: ",
                               RowPartitionTypeToString(row_partition_type));
    }
  }

  if (row_partition_types.empty()) {
    return InvalidArgument("Partition info types should not be empty");
  }
  for (int i = 1; i < row_partition_types.size(); ++i) {
    if (row_partition_types[i] == RowPartitionType::FIRST_DIM_SIZE) {
      return InvalidArgument("FIRST_DIM_SIZE must be first");
    }
  }
  if (row_partition_types[0] == RowPartitionType::FIRST_DIM_SIZE &&
      (row_partition_types.size() < 2 ||
       row_partition_types[1] != RowPartitionType::VALUE_ROWIDS)) {
    return InvalidArgument("FIRST_DIM_SIZE must be followed by VALUE_ROWIDS");
  }
  if (row_partition_types[0] == RowPartitionType::VALUE_ROWIDS) {
    return InvalidArgument("VALUE_ROWIDS cannot be first");
  }

  int num_row_partition_tensors;
  TF_RETURN_IF_ERROR(
      c->GetAttr("num_row_partition_tensors", &num_row_partition_tensors));
  if (num_row_partition_tensors != row_partition_types.size()) {
    return InvalidArgument(
        "Number of row partition tensors (", num_row_partition_tensors,
        ") does not equal the number of row partition types (",
        row_partition_types.size(), ").");
  }

  for (int i = 0; i < num_row_partition_tensors; ++i) {
    TensorShapeProto partition_shape;
    c->ShapeHandleToProto(c->input(3 + i), &partition_shape);
    if (partition_shape.unknown_rank()) {
      continue;
    }
    if (row_partition_types[i] == RowPartitionType::FIRST_DIM_SIZE) {
      if (partition_shape.dim_size() != 0) {
        return InvalidArgument("FIRST_DIM_SIZE must be a scalar.");
      }
    } else {
      if (partition_shape.dim_size() != 1) {
        return InvalidArgument("Row partition must be a vector.");
      }
    }
  }
  return tensorflow::Status::OK();
}

}  // namespace

Status RaggedTensorToSparseShapeFn(InferenceContext* c);
Status RaggedTensorToVariantShapeFn(InferenceContext* c);
Status RaggedTensorFromVariantShapeFn(InferenceContext* c);
Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c);
Status RaggedTensorToTensorShapeFn(InferenceContext* c);

//==============================================================================
// Registered Ops
//==============================================================================

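// Converts a ragged tensor, given as RAGGED_RANK nested row-splits vectors
// plus a flat values tensor, into the components of the equivalent
// SparseTensor (indices, values, and dense shape).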
REGISTER_OP("RaggedTensorToSparse")
    .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
    .Input("rt_dense_values: T")
    .Output("sparse_indices: int64")
    .Output("sparse_values: T")
    .Output("sparse_dense_shape: int64")
    .Attr("RAGGED_RANK: int >= 1")
    .Attr("T: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .SetShapeFn(RaggedTensorToSparseShapeFn);

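// Encodes a ragged tensor as a scalar variant tensor; when `batched_input` is
// true, each row is encoded separately and the output is a vector of variants
// with one element per row.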
REGISTER_OP("RaggedTensorToVariant")
    .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
    .Input("rt_dense_values: Tvalues")
    .Output("encoded_ragged: variant")
    .Attr("RAGGED_RANK: int >= 0")
    .Attr("Tvalues: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .Attr("batched_input: bool")
    .SetShapeFn(RaggedTensorToVariantShapeFn);

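// Decodes variant-encoded ragged tensors (as produced by
// RaggedTensorToVariant) back into nested row-splits vectors and a flat
// values tensor. An `input_ragged_rank` of -1 means the input ragged rank
// should be inferred.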
REGISTER_OP("RaggedTensorFromVariant")
    .Input("encoded_ragged: variant")
    .Output("output_nested_splits: output_ragged_rank * Tsplits")
    .Output("output_dense_values: Tvalues")
    .Attr("input_ragged_rank: int >= -1")
    .Attr("output_ragged_rank: int >= 0")
    .Attr("Tvalues: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .SetShapeFn(RaggedTensorFromVariantShapeFn);

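// Gradient helper for RaggedTensorToVariant: given the variant-encoded
// upstream gradient, the row splits, and the shape of the original dense
// values, produces the gradient with respect to the dense values.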
REGISTER_OP("RaggedTensorToVariantGradient")
    .Input("encoded_ragged_grad: variant")
    .Input("row_splits: Tsplits")
    .Input("dense_values_shape: int32")
    .Output("dense_values_grad: Tvalues")
    .Attr("Tvalues: type")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .SetShapeFn(RaggedTensorToVariantGradientShapeFn);

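// Converts a ragged tensor, given as a flat values tensor plus
// `num_row_partition_tensors` row-partition tensors, into a dense tensor of
// the requested `shape`, filling positions with no corresponding ragged value
// with `default_value`.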
REGISTER_OP("RaggedTensorToTensor")
    .Attr("T: type")
    .Attr("Tindex: {int64, int32}")
    .Attr("Tshape: {int64, int32}")
    .Attr("num_row_partition_tensors: int")
    .Attr("row_partition_types: list(string)")
    .Input("shape: Tshape")
    .Input("values: T")
    .Input("default_value: T")
    .Input("row_partition_tensors: num_row_partition_tensors * Tindex")
    .Output("result: T")
    .SetShapeFn(RaggedTensorToTensorShapeFn);

//==============================================================================
// Shape Functions
//==============================================================================

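// Shape function for RaggedTensorToSparse. With N = number of values in
// rt_dense_values and D = rank(rt_dense_values) + RAGGED_RANK, the outputs
// are sparse_indices: [N, D], sparse_values: [N], and sparse_dense_shape: [D].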
Status RaggedTensorToSparseShapeFn(InferenceContext* c) {
  int64 num_splits;
  TF_RETURN_IF_ERROR(c->GetAttr<int64>("RAGGED_RANK", &num_splits));
  // TODO(b/112274756): Allow ragged_rank to be 0.
  if (num_splits < 1) {
    return errors::InvalidArgument("Requires RAGGED_RANK>0");
  }
  ShapeHandle rt_dense_values = c->input(num_splits);
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(rt_dense_values, 1, &rt_dense_values));

  // Check that all rt_nested_splits have rank 1.
  for (int64 i = 0; i < num_splits; ++i) {
    ShapeHandle splits = c->input(i);
    TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits));
  }

  DimensionHandle dense_dims =
      c->RankKnown(rt_dense_values)
          ? c->MakeDim(c->Rank(rt_dense_values) + num_splits)
          : c->UnknownDim();
  DimensionHandle num_values = c->NumElements(rt_dense_values);

  c->set_output(0, c->Matrix(num_values, dense_dims));  // indices
  c->set_output(1, c->Vector(num_values));              // values
  c->set_output(2, c->Vector(dense_dims));              // dense_shape

  return Status::OK();
}

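// Shape function for RaggedTensorToVariant. The output is a scalar, or, when
// `batched_input` is true, a vector whose length is the number of rows (the
// length of the outermost splits vector minus one).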
Status RaggedTensorToVariantShapeFn(InferenceContext* c) {
  int64 num_splits;
  TF_RETURN_IF_ERROR(c->GetAttr<int64>("RAGGED_RANK", &num_splits));
  bool batched;
  TF_RETURN_IF_ERROR(c->GetAttr<bool>("batched_input", &batched));
  shape_inference::ShapeHandle rt_dense_values = c->input(num_splits);
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(rt_dense_values, 1, &rt_dense_values));
  for (int64 i = 0; i < num_splits; ++i) {
    shape_inference::ShapeHandle splits = c->input(i);
    TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits));
  }
  if (batched) {
    auto num_first_splits = c->Dim(c->input(0), 0);
    shape_inference::DimensionHandle num_rows;
    TF_RETURN_IF_ERROR(c->Subtract(num_first_splits, 1, &num_rows));
    c->set_output(0, c->Vector(num_rows));
  } else {
    c->set_output(0, c->Scalar());
  }
  if (batched && num_splits == 0) {
    return errors::InvalidArgument(
        "ragged_rank=0 is not currently supported when batched_input=true.");
  }
  return Status::OK();
}

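// Shape function for RaggedTensorToVariantGradient. The gradient output takes
// the shape described by the `dense_values_shape` input (a scalar shape
// tensor is treated as an unknown shape).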
Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) {
  ShapeHandle shape;
  TF_RETURN_IF_ERROR(
      c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape));
  c->set_output(0, shape);
  return Status::OK();
}

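// Shape function for RaggedTensorFromVariant. Each of the output_ragged_rank
// splits outputs is a vector of unknown length and the dense values output
// has unknown shape; when the input rank and input_ragged_rank are known, the
// encoded input must have rank output_ragged_rank - input_ragged_rank.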
Status RaggedTensorFromVariantShapeFn(InferenceContext* c) {
  int64 input_ragged_rank;
  TF_RETURN_IF_ERROR(
      c->GetAttr<int64>("input_ragged_rank", &input_ragged_rank));
  int64 output_ragged_rank;
  TF_RETURN_IF_ERROR(
      c->GetAttr<int64>("output_ragged_rank", &output_ragged_rank));
  shape_inference::ShapeHandle encoded_ragged = c->input(0);
  if (c->RankKnown(encoded_ragged) && input_ragged_rank >= 0) {
    shape_inference::ShapeHandle unused;
    TF_RETURN_IF_ERROR(c->WithRank(
        encoded_ragged, output_ragged_rank - input_ragged_rank, &unused));
  }
  for (int64 i = 0; i < output_ragged_rank; i++) {
    c->set_output(i, c->UnknownShapeOfRank(1));
  }
  c->set_output(output_ragged_rank, c->UnknownShape());
  return Status::OK();
}

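// Shape function for RaggedTensorToTensor. Validates the row partition types
// and shapes, checks that `default_value` is compatible with the values
// shape, and combines the requested `shape` with the values shape to infer
// the output shape.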
tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c) {
  TensorShapeProto shape;
  {
    ShapeHandle shape_handle;
    TF_RETURN_IF_ERROR(
        c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &shape_handle));
    c->ShapeHandleToProto(shape_handle, &shape);
  }

  std::vector<RowPartitionType> row_partition_types;
  TF_RETURN_IF_ERROR(GetRowPartitionTypes(c, &row_partition_types));
  int ragged_rank = GetRaggedRank(row_partition_types);
  TF_RETURN_IF_ERROR(
      ValidateRowPartitionTypesAndShapes(row_partition_types, c));

  TensorShapeProto value_shape;
  c->ShapeHandleToProto(c->input(1), &value_shape);

  TensorShapeProto default_value_shape;
  c->ShapeHandleToProto(c->input(2), &default_value_shape);

  TF_RETURN_IF_ERROR(
      ValidateDefaultValueShape(default_value_shape, value_shape));

  // TODO(martinz): Theoretically, we could check the first dimension of
  // value_shape against the first dimension of the last row_partition_tensor
  // assuming it is a VALUE_ROWIDS type.
  // TODO(martinz): Although we normally don't know the first dimension of the
  // output, we could infer it from the first dimension of the first
  // row_partition_tensor if it is ROW_SPLITS type.
  // TODO(martinz): If the shape is provided, but the value_shape has missing
  // dimensions, we can check the default_value_shape against the shape.
  TensorShapeProto output_shape;
  TF_RETURN_IF_ERROR(CombineRaggedTensorToTensorShapes(
      ragged_rank, shape, value_shape, &output_shape));

  ShapeHandle output_shape_handle;
  TF_RETURN_IF_ERROR(
      c->MakeShapeFromShapeProto(output_shape, &output_shape_handle));
  c->set_output(0, output_shape_handle);
  return Status::OK();
}

}  // namespace tensorflow