• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/fake_input.h"
17 #include "tensorflow/core/framework/node_def_builder.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 #include "tensorflow/core/framework/shape_inference_testutil.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/framework/tensor_testutil.h"
23 #include "tensorflow/core/kernels/ops_testutil.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/lib/strings/str_util.h"
26 #include "tensorflow/core/platform/test.h"
27 
28 namespace tensorflow {
29 namespace {
30 
31 class RaggedTensorToSparseTest : public ::tensorflow::OpsTestBase {
32  protected:
33   static constexpr int kSparseIndicesOutput = 0;
34   static constexpr int kSparseValuesOutput = 1;
35   static constexpr int kSparseDenseShapeOutput = 2;
36   // Builds the tensorflow test graph for the RaggedTensorToSparse op, and
37   // populates the `splits` input with the given values.
38   template <typename T>
BuildRaggedTensorToSparseGraph(const std::vector<std::vector<int64>> & rt_nested_splits,const TensorShape & rt_dense_values_shape,const std::vector<T> & rt_dense_values)39   void BuildRaggedTensorToSparseGraph(
40       const std::vector<std::vector<int64>>& rt_nested_splits,
41       const TensorShape& rt_dense_values_shape,
42       const std::vector<T>& rt_dense_values) {
43     const auto& dtype = DataTypeToEnum<T>::v();
44     int64_t num_splits = rt_nested_splits.size();
45     TF_ASSERT_OK(NodeDefBuilder("tested_op", "RaggedTensorToSparse")
46                      .Input(FakeInput(num_splits))  // rt_nested_splits
47                      .Input(FakeInput(dtype))       // rt_dense_values
48                      .Attr("RAGGED_RANK", num_splits)
49                      .Attr("T", dtype)
50                      .Finalize(node_def()));
51     TF_ASSERT_OK(InitOp());
52     for (const auto& splits : rt_nested_splits) {
53       int64_t splits_size = splits.size();
54       AddInputFromArray<int64>(TensorShape({splits_size}), splits);
55     }
56     AddInputFromArray<T>(rt_dense_values_shape, rt_dense_values);
57   }
58 };
59 
TEST_F(RaggedTensorToSparseTest, OneSplits_Values1D) {
  // ragged_tensor=[[1, 2, 3], [], [4, 5], [6]]
  // A fatal failure inside the helper only returns from the helper; this
  // makes it abort the test before the kernel is run.
  ASSERT_NO_FATAL_FAILURE(
      BuildRaggedTensorToSparseGraph<int>({{0, 3, 3, 5, 6}},     // splits
                                          TensorShape({6}),      // values.shape
                                          {1, 2, 3, 4, 5, 6}));  // values
  TF_ASSERT_OK(RunOpKernel());
  // One (row, col) pair per value, in row-major order.
  test::ExpectTensorEqual<int64>(
      *GetOutput(kSparseIndicesOutput),
      test::AsTensor<int64>({0, 0, 0, 1, 0, 2, 2, 0, 2, 1, 3, 0}, {6, 2}));
  test::ExpectTensorEqual<int>(*GetOutput(kSparseValuesOutput),
                               test::AsTensor<int>({1, 2, 3, 4, 5, 6}));
  // Dense shape is [number of rows, length of the longest row].
  test::ExpectTensorEqual<int64>(*GetOutput(kSparseDenseShapeOutput),
                                 test::AsTensor<int64>({4, 3}));
}
74 
TEST_F(RaggedTensorToSparseTest, EmptyRows) {
  // Empty rows at the beginning, middle, and end of the RaggedTensor.
  // ragged_tensor=[[], [1, 2, 3, 4], [], [5, 6], []]
  // Abort the test (not just the helper) if graph construction fails.
  ASSERT_NO_FATAL_FAILURE(
      BuildRaggedTensorToSparseGraph<int>({{0, 0, 4, 4, 6, 6}},  // splits
                                          TensorShape({6}),      // values.shape
                                          {1, 2, 3, 4, 5, 6}));  // values
  TF_ASSERT_OK(RunOpKernel());
  // Empty rows contribute no indices, but they still count toward the number
  // of rows in the dense shape.
  test::ExpectTensorEqual<int64>(
      *GetOutput(kSparseIndicesOutput),
      test::AsTensor<int64>({1, 0, 1, 1, 1, 2, 1, 3, 3, 0, 3, 1}, {6, 2}));
  test::ExpectTensorEqual<int>(*GetOutput(kSparseValuesOutput),
                               test::AsTensor<int>({1, 2, 3, 4, 5, 6}));
  test::ExpectTensorEqual<int64>(*GetOutput(kSparseDenseShapeOutput),
                                 test::AsTensor<int64>({5, 4}));
}
90 
TEST_F(RaggedTensorToSparseTest, OneSplits_Values2D) {
  // ragged_tensor=[[[1, 2], [3, 4], [5, 6]], [], [[7, 8], [9, 10]], [[11, 12]]]
  // Abort the test (not just the helper) if graph construction fails.
  ASSERT_NO_FATAL_FAILURE(BuildRaggedTensorToSparseGraph<int>(
      {{0, 3, 3, 5, 6}},                          // splits
      TensorShape({6, 2}),                        // values.shape
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));  // values
  TF_ASSERT_OK(RunOpKernel());
  // Every scalar gets a (row, col, inner) index. `inner` is the position
  // along the uniform size-2 dimension of the values, so the indices are
  // [12, 3].
  const std::vector<int64> expected_indices = {
      0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 2, 0, 0, 2, 1,
      2, 0, 0, 2, 0, 1, 2, 1, 0, 2, 1, 1, 3, 0, 0, 3, 0, 1};
  const std::vector<int> expected_values = {1, 2, 3, 4,  5,  6,
                                            7, 8, 9, 10, 11, 12};
  test::ExpectTensorEqual<int64>(
      *GetOutput(kSparseIndicesOutput),
      test::AsTensor<int64>(expected_indices, {12, 3}));
  test::ExpectTensorEqual<int>(*GetOutput(kSparseValuesOutput),
                               test::AsTensor<int>(expected_values));
  test::ExpectTensorEqual<int64>(*GetOutput(kSparseDenseShapeOutput),
                                 test::AsTensor<int64>({4, 3, 2}));
}
110 
TEST_F(RaggedTensorToSparseTest, TwoSplits_Values1D) {
  // ragged_tensor =
  //        0             1           2
  // -+--------------------------------------
  // 0| [[ [x],         [x x],       [] ],
  // 1|  [                              ],
  // 2|  [ [x x x x x], [x x x]         ],
  // 3|  [ [],          [x x x x]       ]]
  // Abort the test (not just the helper) if graph construction fails.
  ASSERT_NO_FATAL_FAILURE(BuildRaggedTensorToSparseGraph<int>(
      {{0, 3, 3, 5, 7}, {0, 1, 3, 3, 8, 11, 11, 15}},         // splits
      TensorShape({15}),                                      // values.shape
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}));  // values
  TF_ASSERT_OK(RunOpKernel());
  // One (row, col, inner) index per value. The dense shape is
  // [rows, widest row, longest innermost list].
  const std::vector<int64> expected_indices = {
      0, 0, 0, 0, 1, 0, 0, 1, 1, 2, 0, 0, 2, 0, 1, 2, 0, 2, 2, 0, 3, 2, 0,
      4, 2, 1, 0, 2, 1, 1, 2, 1, 2, 3, 1, 0, 3, 1, 1, 3, 1, 2, 3, 1, 3};
  const std::vector<int> expected_values = {1, 2,  3,  4,  5,  6,  7, 8,
                                            9, 10, 11, 12, 13, 14, 15};
  test::ExpectTensorEqual<int64>(
      *GetOutput(kSparseIndicesOutput),
      test::AsTensor<int64>(expected_indices, {15, 3}));
  test::ExpectTensorEqual<int>(*GetOutput(kSparseValuesOutput),
                               test::AsTensor<int>(expected_values));
  test::ExpectTensorEqual<int64>(*GetOutput(kSparseDenseShapeOutput),
                                 test::AsTensor<int64>({4, 3, 5}));
}
137 
TEST_F(RaggedTensorToSparseTest, ShapeFn) {
  // RaggedTensorToSparse(rt_nested_splits+, rt_dense_values)
  //     -> [sparse_indices, sparse_values, sparse_dense_shape]
  // The output shape will always have the following form:
  //     [nvals, dense_dims];[nvals];[dense_dims]
  // where nvals is the total number of elements in rt_dense_values (the
  // product of its dimensions) and dense_dims = RAGGED_RANK + rank of
  // rt_dense_values.
  ShapeInferenceTestOp op("RaggedTensorToSparse");

  // Tests with len(rt_nested_splits)==0.
  (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(0);
  INFER_ERROR("Requires RAGGED_RANK>0", op, "?");

  // Tests with len(rt_nested_splits)==1.
  (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(1);
  INFER_OK(op, "?;?", "[?,?];[?];[?]");          // nvals=?, dense_dims=?
  INFER_OK(op, "?;[?]", "[?,2];[?];[2]");        // nvals=?, dense_dims=2
  INFER_OK(op, "?;[?,?]", "[?,3];[?];[3]");      // nvals=?, dense_dims=3
  INFER_OK(op, "[?];[5]", "[5,2];[5];[2]");      // nvals=5, dense_dims=2
  INFER_OK(op, "[?];[5,2]", "[10,3];[10];[3]");  // nvals=10, dense_dims=3
  // Each splits tensor must be a vector; the values must be at least rank 1.
  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];?");
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[5,5];?");
  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "?;[]");

  // Tests with len(rt_nested_splits)==2
  (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(2);
  INFER_OK(op, "?;?;?", "[?,?];[?];[?]");            // nvals=?, dense_dims=?
  INFER_OK(op, "?;?;[?]", "[?,3];[?];[3]");          // nvals=?, dense_dims=3
  INFER_OK(op, "?;?;[?,?]", "[?,4];[?];[4]");        // nvals=?, dense_dims=4
  INFER_OK(op, "[?];[?];[5]", "[5,3];[5];[3]");      // nvals=5, dense_dims=3
  INFER_OK(op, "[?];[?];[5,2]", "[10,4];[10];[4]");  // nvals=10, dense_dims=4
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[5,5];?");

  // Tests with len(rt_nested_splits)==3
  (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(3);
  INFER_OK(op, "?;?;?;?", "[?,?];[?];[?]");    // nvals=?, dense_dims=?
  INFER_OK(op, "?;?;?;[?]", "[?,4];[?];[4]");  // nvals=?, dense_dims=4
  INFER_OK(op, "?;?;?;[5]", "[5,4];[5];[4]");  // nvals=5, dense_dims=4
}
175 
TEST_F(RaggedTensorToSparseTest, NoSplits) {
  // A node with RAGGED_RANK=0 can be built, but kernel construction must
  // reject it because the attr has a minimum value of 1.
  const DataType value_dtype = DataTypeToEnum<int>::v();
  TF_ASSERT_OK(NodeDefBuilder("tested_op", "RaggedTensorToSparse")
                   .Input(FakeInput(0))
                   .Input(FakeInput(value_dtype))
                   .Attr("RAGGED_RANK", 0)
                   .Attr("T", value_dtype)
                   .Finalize(node_def()));
  const Status init_status = InitOp();
  const bool has_expected_prefix = absl::StartsWith(
      init_status.error_message(),
      "Value for attr 'RAGGED_RANK' of 0 must be at least minimum 1");
  EXPECT_TRUE(has_expected_prefix);
}
188 
TEST_F(RaggedTensorToSparseTest, InvalidArg_BadSplitStart) {
  // The splits start at 5 instead of 0.
  // Abort the test (not just the helper) if graph construction fails.
  ASSERT_NO_FATAL_FAILURE(
      BuildRaggedTensorToSparseGraph<int>({{5, 7, 10}},       // splits
                                          TensorShape({0}),   // values.shape
                                          {}));               // values
  EXPECT_EQ("First value of ragged splits must be 0.",
            RunOpKernel().error_message());
}
196 
TEST_F(RaggedTensorToSparseTest, InvalidArg_BadSplitLengths1) {
  // Two levels of splits. The outer splits end at 5, but the inner splits
  // describe only 3 rows (and the inner splits end at 6 while there are no
  // values). Either mismatch is reported with the same message.
  // Abort the test (not just the helper) if graph construction fails.
  ASSERT_NO_FATAL_FAILURE(
      BuildRaggedTensorToSparseGraph<int>({{0, 5}, {0, 2, 4, 6}},  // splits
                                          TensorShape({0}),  // values.shape
                                          {}));              // values
  EXPECT_EQ(
      "Final value of ragged splits must match the length "
      "the corresponding ragged values.",
      RunOpKernel().error_message());
}
206 
TEST_F(RaggedTensorToSparseTest, InvalidArg_BadSplitLengths2) {
  // A single level of splits that ends at 5 while the values are empty.
  // Abort the test (not just the helper) if graph construction fails.
  ASSERT_NO_FATAL_FAILURE(
      BuildRaggedTensorToSparseGraph<int>({{0, 5}},          // splits
                                          TensorShape({0}),  // values.shape
                                          {}));              // values
  EXPECT_EQ(
      "Final value of ragged splits must match the length "
      "the corresponding ragged values.",
      RunOpKernel().error_message());
}
216 
TEST_F(RaggedTensorToSparseTest, InvalidArg_EmptySplits) {
  // A splits tensor with no elements (not even the leading 0) is rejected.
  // Abort the test (not just the helper) if graph construction fails.
  ASSERT_NO_FATAL_FAILURE(
      BuildRaggedTensorToSparseGraph<int>({{}},              // splits
                                          TensorShape({0}),  // values.shape
                                          {}));              // values
  EXPECT_EQ("ragged splits may not be empty.", RunOpKernel().error_message());
}
223 
224 }  // namespace
225 }  // namespace tensorflow
226