1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/common_shape_fns.h"
16 #include "tensorflow/core/framework/op.h"
17 #include "tensorflow/core/framework/shape_inference.h"
18
19 namespace tensorflow {
20
21 using shape_inference::DimensionHandle;
22 using shape_inference::InferenceContext;
23 using shape_inference::ShapeHandle;
24
// Shape function for the RaggedGather op. It is defined further down in this
// file and declared here so the registration below can reference it.
Status RaggedGatherShapeFn(InferenceContext* c);

//==============================================================================
// Registered Ops
//==============================================================================

// Registers the RaggedGather op.
//
// `params` is passed in decomposed form:
//   - PARAMS_RAGGED_RANK row-splits tensors (`params_nested_splits`), and
//   - one `params_dense_values` tensor.
// `indices` selects from it. The gather itself is implemented by the kernel,
// which is not shown here.
//
// The result is returned in the same decomposed form:
//   - OUTPUT_RAGGED_RANK splits tensors (`output_nested_splits`), and
//   - one `output_dense_values` tensor.
// All splits tensors share the `Tsplits` dtype (int32 or int64, default
// int64).
//
// NOTE: the order of the .Input/.Output/.Attr calls defines the OpDef
// argument order, so do not reorder them.
REGISTER_OP("RaggedGather")
    .Input("params_nested_splits: PARAMS_RAGGED_RANK * Tsplits")
    .Input("params_dense_values: Tvalues")
    .Input("indices: Tindices")
    .Output("output_nested_splits: OUTPUT_RAGGED_RANK * Tsplits")
    .Output("output_dense_values: Tvalues")
    .Attr("Tvalues: type")
    .Attr("Tindices: {int32, int64}")
    .Attr("Tsplits: {int32, int64} = DT_INT64")
    .Attr("PARAMS_RAGGED_RANK: int >= 1")
    .Attr("OUTPUT_RAGGED_RANK: int >= 0")
    .SetShapeFn(RaggedGatherShapeFn);
43
44 REGISTER_OP("RaggedCross")
45 .Input("ragged_values: ragged_values_types")
46 .Input("ragged_row_splits: ragged_splits_types")
47 .Input("sparse_indices: Nsparse * int64")
48 .Input("sparse_values: sparse_values_types")
49 .Input("sparse_shape: Nsparse * int64")
50 .Input("dense_inputs: dense_types")
51 .Output("output_values: out_values_type")
52 .Output("output_row_splits: out_row_splits_type")
53 .Attr("Nsparse: int >= 0")
54 .Attr("input_order: string")
55 .Attr("hashed_output: bool")
56 .Attr("num_buckets: int >= 0")
57 .Attr("hash_key: int")
58 .Attr("ragged_values_types: list({int64, string}) >= 0")
59 .Attr("ragged_splits_types: list({int32, int64}) >= 0")
60 .Attr("sparse_values_types: list({int64, string}) >= 0")
61 .Attr("dense_types: list({int64, string}) >= 0")
62 .Attr("out_values_type: {int64, string}")
63 .Attr("out_row_splits_type: {int32, int64}")
__anon9deb1dc20102(shape_inference::InferenceContext* c) 64 .SetShapeFn([](shape_inference::InferenceContext* c) {
65 std::vector<DataType> ragged_values_types;
66 std::vector<DataType> ragged_splits_types;
67 std::vector<DataType> dense_types;
68
69 TF_RETURN_IF_ERROR(
70 c->GetAttr("ragged_values_types", &ragged_values_types));
71 TF_RETURN_IF_ERROR(
72 c->GetAttr("ragged_splits_types", &ragged_splits_types));
73 TF_RETURN_IF_ERROR(c->GetAttr("dense_types", &dense_types));
74
75 int num_ragged = ragged_values_types.size();
76 if (num_ragged != ragged_splits_types.size()) {
77 return errors::InvalidArgument(
78 "Parameters `values` and `row_splits` must be the same length");
79 }
80
81 int num_sparse;
82 TF_RETURN_IF_ERROR(c->GetAttr("Nsparse", &num_sparse));
83
84 ShapeHandle out_values = c->UnknownShapeOfRank(1);
85 ShapeHandle out_splits = c->UnknownShapeOfRank(1);
86
87 // Merge the shapes of row_splits from ragged inputs. (This is one plus
88 // the batch size.)
89 int ragged_splits_start = num_ragged;
90 for (int i = 0; i < ragged_splits_types.size(); ++i) {
91 ShapeHandle row_splits = c->input(i + ragged_splits_start);
92 if (!c->Merge(out_splits, row_splits, &out_splits).ok()) {
93 return errors::InvalidArgument(
94 "inputs must all have the same batch dimension size.");
95 }
96 }
97
98 // Merge the batch size of each dense input into out_splits.
99 int dense_start = num_ragged * 2 + num_sparse * 3;
100 for (int i = 0; i < dense_types.size(); ++i) {
101 ShapeHandle dense_input = c->input(i + dense_start);
102 int64 batch_size = c->Value(c->Dim(dense_input, 0));
103 if (batch_size != InferenceContext::kUnknownDim) {
104 ShapeHandle row_splits = c->Vector(batch_size + 1);
105 if (!c->Merge(out_splits, row_splits, &out_splits).ok()) {
106 return errors::InvalidArgument(
107 "inputs must all have the same batch dimension size.");
108 }
109 }
110 }
111
112 c->set_output(0, out_values);
113 c->set_output(1, out_splits);
114 return Status::OK();
115 });
116
117 //==============================================================================
118 // Shape Functions
119 //==============================================================================
120
RaggedGatherShapeFn(InferenceContext * c)121 Status RaggedGatherShapeFn(InferenceContext* c) {
122 int num_splits;
123 int64 PARAMS_RAGGED_RANK;
124 TF_RETURN_IF_ERROR(
125 c->GetAttr<int64>("PARAMS_RAGGED_RANK", &PARAMS_RAGGED_RANK));
126 TF_RETURN_IF_ERROR(c->GetAttr<int>("OUTPUT_RAGGED_RANK", &num_splits));
127
128 // Check rank of `indices`.
129 ShapeHandle indices = c->input(PARAMS_RAGGED_RANK + 1);
130 TF_RETURN_IF_ERROR(
131 c->WithRank(indices, num_splits - PARAMS_RAGGED_RANK + 1, &indices));
132
133 // Check that all params_nested_splits have rank 1.
134 for (int64 i = 0; i < PARAMS_RAGGED_RANK; ++i) {
135 ShapeHandle splits = c->input(i);
136 TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits));
137 }
138
139 // Check that `params_dense_values` has rank>=1.
140 ShapeHandle params_dense_values = c->input(PARAMS_RAGGED_RANK);
141 TF_RETURN_IF_ERROR(
142 c->WithRankAtLeast(params_dense_values, 1, ¶ms_dense_values));
143
144 // Set the rank for the `splits` outputs.
145 for (int i = 0; i < num_splits; ++i) {
146 c->set_output(i, c->UnknownShapeOfRank(1));
147 }
148
149 // Calculate the `values` shape.
150 ShapeHandle value = c->UnknownShape();
151 ShapeHandle values = c->UnknownShape();
152 TF_RETURN_IF_ERROR(c->Subshape(params_dense_values, 1, &value));
153 TF_RETURN_IF_ERROR(c->Concatenate(c->UnknownShapeOfRank(1), value, &values));
154 c->set_output(num_splits, values);
155
156 return Status::OK();
157 }
158
159 } // namespace tensorflow
160