• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/common_shape_fns.h"
16 #include "tensorflow/core/framework/op.h"
17 #include "tensorflow/core/framework/shape_inference.h"
18 
19 namespace tensorflow {
20 
21 using shape_inference::DimensionHandle;
22 using shape_inference::InferenceContext;
23 using shape_inference::ShapeHandle;
24 
25 Status RaggedGatherShapeFn(InferenceContext* c);
26 
27 //==============================================================================
28 // Registered Ops
29 //==============================================================================
30 
31 REGISTER_OP("RaggedGather")
32     .Input("params_nested_splits: PARAMS_RAGGED_RANK * Tsplits")
33     .Input("params_dense_values: Tvalues")
34     .Input("indices: Tindices")
35     .Output("output_nested_splits: OUTPUT_RAGGED_RANK * Tsplits")
36     .Output("output_dense_values: Tvalues")
37     .Attr("Tvalues: type")
38     .Attr("Tindices: {int32, int64}")
39     .Attr("Tsplits: {int32, int64} = DT_INT64")
40     .Attr("PARAMS_RAGGED_RANK: int >= 1")
41     .Attr("OUTPUT_RAGGED_RANK: int >= 0")
42     .SetShapeFn(RaggedGatherShapeFn);
43 
44 REGISTER_OP("RaggedCross")
45     .Input("ragged_values: ragged_values_types")
46     .Input("ragged_row_splits: ragged_splits_types")
47     .Input("sparse_indices: Nsparse * int64")
48     .Input("sparse_values: sparse_values_types")
49     .Input("sparse_shape: Nsparse * int64")
50     .Input("dense_inputs: dense_types")
51     .Output("output_values: out_values_type")
52     .Output("output_row_splits: out_row_splits_type")
53     .Attr("Nsparse: int >= 0")
54     .Attr("input_order: string")
55     .Attr("hashed_output: bool")
56     .Attr("num_buckets: int >= 0")
57     .Attr("hash_key: int")
58     .Attr("ragged_values_types: list({int64, string}) >= 0")
59     .Attr("ragged_splits_types: list({int32, int64}) >= 0")
60     .Attr("sparse_values_types: list({int64, string}) >= 0")
61     .Attr("dense_types: list({int64, string}) >= 0")
62     .Attr("out_values_type: {int64, string}")
63     .Attr("out_row_splits_type: {int32, int64}")
__anon9deb1dc20102(shape_inference::InferenceContext* c) 64     .SetShapeFn([](shape_inference::InferenceContext* c) {
65       std::vector<DataType> ragged_values_types;
66       std::vector<DataType> ragged_splits_types;
67       std::vector<DataType> dense_types;
68 
69       TF_RETURN_IF_ERROR(
70           c->GetAttr("ragged_values_types", &ragged_values_types));
71       TF_RETURN_IF_ERROR(
72           c->GetAttr("ragged_splits_types", &ragged_splits_types));
73       TF_RETURN_IF_ERROR(c->GetAttr("dense_types", &dense_types));
74 
75       int num_ragged = ragged_values_types.size();
76       if (num_ragged != ragged_splits_types.size()) {
77         return errors::InvalidArgument(
78             "Parameters `values` and `row_splits` must be the same length");
79       }
80 
81       int num_sparse;
82       TF_RETURN_IF_ERROR(c->GetAttr("Nsparse", &num_sparse));
83 
84       ShapeHandle out_values = c->UnknownShapeOfRank(1);
85       ShapeHandle out_splits = c->UnknownShapeOfRank(1);
86 
87       // Merge the shapes of row_splits from ragged inputs.  (This is one plus
88       // the batch size.)
89       int ragged_splits_start = num_ragged;
90       for (int i = 0; i < ragged_splits_types.size(); ++i) {
91         ShapeHandle row_splits = c->input(i + ragged_splits_start);
92         if (!c->Merge(out_splits, row_splits, &out_splits).ok()) {
93           return errors::InvalidArgument(
94               "inputs must all have the same batch dimension size.");
95         }
96       }
97 
98       // Merge the batch size of each dense input into out_splits.
99       int dense_start = num_ragged * 2 + num_sparse * 3;
100       for (int i = 0; i < dense_types.size(); ++i) {
101         ShapeHandle dense_input = c->input(i + dense_start);
102         int64 batch_size = c->Value(c->Dim(dense_input, 0));
103         if (batch_size != InferenceContext::kUnknownDim) {
104           ShapeHandle row_splits = c->Vector(batch_size + 1);
105           if (!c->Merge(out_splits, row_splits, &out_splits).ok()) {
106             return errors::InvalidArgument(
107                 "inputs must all have the same batch dimension size.");
108           }
109         }
110       }
111 
112       c->set_output(0, out_values);
113       c->set_output(1, out_splits);
114       return Status::OK();
115     });
116 
117 //==============================================================================
118 // Shape Functions
119 //==============================================================================
120 
RaggedGatherShapeFn(InferenceContext * c)121 Status RaggedGatherShapeFn(InferenceContext* c) {
122   int num_splits;
123   int64 PARAMS_RAGGED_RANK;
124   TF_RETURN_IF_ERROR(
125       c->GetAttr<int64>("PARAMS_RAGGED_RANK", &PARAMS_RAGGED_RANK));
126   TF_RETURN_IF_ERROR(c->GetAttr<int>("OUTPUT_RAGGED_RANK", &num_splits));
127 
128   // Check rank of `indices`.
129   ShapeHandle indices = c->input(PARAMS_RAGGED_RANK + 1);
130   TF_RETURN_IF_ERROR(
131       c->WithRank(indices, num_splits - PARAMS_RAGGED_RANK + 1, &indices));
132 
133   // Check that all params_nested_splits have rank 1.
134   for (int64 i = 0; i < PARAMS_RAGGED_RANK; ++i) {
135     ShapeHandle splits = c->input(i);
136     TF_RETURN_IF_ERROR(c->WithRank(splits, 1, &splits));
137   }
138 
139   // Check that `params_dense_values` has rank>=1.
140   ShapeHandle params_dense_values = c->input(PARAMS_RAGGED_RANK);
141   TF_RETURN_IF_ERROR(
142       c->WithRankAtLeast(params_dense_values, 1, &params_dense_values));
143 
144   // Set the rank for the `splits` outputs.
145   for (int i = 0; i < num_splits; ++i) {
146     c->set_output(i, c->UnknownShapeOfRank(1));
147   }
148 
149   // Calculate the `values` shape.
150   ShapeHandle value = c->UnknownShape();
151   ShapeHandle values = c->UnknownShape();
152   TF_RETURN_IF_ERROR(c->Subshape(params_dense_values, 1, &value));
153   TF_RETURN_IF_ERROR(c->Concatenate(c->UnknownShapeOfRank(1), value, &values));
154   c->set_output(num_splits, values);
155 
156   return Status::OK();
157 }
158 
159 }  // namespace tensorflow
160