• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/fake_input.h"
17 #include "tensorflow/core/framework/node_def_builder.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 #include "tensorflow/core/framework/shape_inference_testutil.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/framework/tensor_testutil.h"
23 #include "tensorflow/core/kernels/ops_testutil.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 
27 namespace tensorflow {
28 namespace {
29 
30 class RaggedRangeOpTest : public ::tensorflow::OpsTestBase {
31  protected:
32   // Indices of output tensors.
33   static constexpr int kSplitsOutput = 0;
34   static constexpr int kValuesOutput = 1;
35 
36   // Builds the tensorflow test graph for the RaggedRange op.
37   template <typename T>
BuildRaggedRangeGraph()38   void BuildRaggedRangeGraph() {
39     const auto& dtype = DataTypeToEnum<T>::v();
40     TF_ASSERT_OK(NodeDefBuilder("tested_op", "RaggedRange")
41                      .Input(FakeInput(dtype))  // starts
42                      .Input(FakeInput(dtype))  // limits
43                      .Input(FakeInput(dtype))  // deltas
44                      .Attr("T", dtype)
45                      .Finalize(node_def()));
46     TF_ASSERT_OK(InitOp());
47   }
48 };
49 
TEST_F(RaggedRangeOpTest, IntValues) {
  BuildRaggedRangeGraph<int>();
  // Four rows: stride 2, stride 1, an empty range, and a descending range.
  AddInputFromArray<int>(TensorShape({4}), {0, 5, 8, 5});   // starts
  AddInputFromArray<int>(TensorShape({4}), {8, 7, 8, 1});   // limits
  AddInputFromArray<int>(TensorShape({4}), {2, 1, 1, -1});  // deltas
  TF_ASSERT_OK(RunOpKernel());

  // Ragged result: [[0, 2, 4, 6], [5, 6], [], [5, 4, 3, 2]]
  const Tensor expected_splits = test::AsTensor<int64>({0, 4, 6, 6, 10});
  const Tensor expected_values =
      test::AsTensor<int>({0, 2, 4, 6, 5, 6, 5, 4, 3, 2});
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput), expected_splits);
  test::ExpectTensorEqual<int>(*GetOutput(kValuesOutput), expected_values);
}
64 
TEST_F(RaggedRangeOpTest, FloatValues) {
  BuildRaggedRangeGraph<float>();
  // Same ranges as IntValues, but with float starts/limits/deltas.
  AddInputFromArray<float>(TensorShape({4}), {0, 5, 8, 5});   // starts
  AddInputFromArray<float>(TensorShape({4}), {8, 7, 8, 1});   // limits
  AddInputFromArray<float>(TensorShape({4}), {2, 1, 1, -1});  // deltas
  TF_ASSERT_OK(RunOpKernel());

  // Ragged result: [[0, 2, 4, 6], [5, 6], [], [5, 4, 3, 2]]
  // Splits are always int64, independent of T.
  const Tensor expected_splits = test::AsTensor<int64>({0, 4, 6, 6, 10});
  const Tensor expected_values =
      test::AsTensor<float>({0, 2, 4, 6, 5, 6, 5, 4, 3, 2});
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput), expected_splits);
  test::ExpectTensorNear<float>(*GetOutput(kValuesOutput), expected_values,
                                0.1);
}
79 
TEST_F(RaggedRangeOpTest, BroadcastDeltas) {
  BuildRaggedRangeGraph<int>();
  // A scalar delta is applied to every row.
  AddInputFromArray<int>(TensorShape({3}), {0, 5, 8});  // starts
  AddInputFromArray<int>(TensorShape({3}), {8, 7, 8});  // limits
  AddInputFromArray<int>(TensorShape({}), {1});         // deltas (scalar)
  TF_ASSERT_OK(RunOpKernel());

  // Ragged result: [[0, 1, 2, 3, 4, 5, 6, 7], [5, 6], []]
  const Tensor expected_splits = test::AsTensor<int64>({0, 8, 10, 10});
  const Tensor expected_values =
      test::AsTensor<int>({0, 1, 2, 3, 4, 5, 6, 7, 5, 6});
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput), expected_splits);
  test::ExpectTensorEqual<int>(*GetOutput(kValuesOutput), expected_values);
}
94 
// NOTE(review): despite its name, this test broadcasts `starts` and `deltas`
// (both scalars); `limits` is the vector that determines the row count.
// Consider renaming to BroadcastStartsAndDeltas. The name is kept here so
// existing test filters keep matching.
TEST_F(RaggedRangeOpTest, BroadcastLimitsAndDeltas) {
  BuildRaggedRangeGraph<int>();
  AddInputFromArray<int>(TensorShape({}), {0});         // starts (scalar)
  AddInputFromArray<int>(TensorShape({3}), {3, 0, 2});  // limits
  AddInputFromArray<int>(TensorShape({}), {1});         // deltas (scalar)
  TF_ASSERT_OK(RunOpKernel());

  // Expected: [[0, 1, 2], [], [0, 1]]
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput),
                                 test::AsTensor<int64>({0, 3, 3, 5}));
  test::ExpectTensorEqual<int>(*GetOutput(kValuesOutput),
                               test::AsTensor<int>({0, 1, 2, 0, 1}));
}
108 
TEST_F(RaggedRangeOpTest, BroadcastStartsAndLimits) {
  BuildRaggedRangeGraph<int>();
  // Scalar starts and limits are shared by every row; the deltas vector
  // determines the row count.
  AddInputFromArray<int>(TensorShape({}), {0});         // starts (scalar)
  AddInputFromArray<int>(TensorShape({}), {12});        // limits (scalar)
  AddInputFromArray<int>(TensorShape({3}), {3, 4, 5});  // deltas
  TF_ASSERT_OK(RunOpKernel());

  // Expected: [[0, 3, 6, 9], [0, 4, 8], [0, 5, 10]]
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput),
                                 test::AsTensor<int64>({0, 4, 7, 10}));
  test::ExpectTensorEqual<int>(
      *GetOutput(kValuesOutput),
      test::AsTensor<int>({0, 3, 6, 9, 0, 4, 8, 0, 5, 10}));
}
123 
TEST_F(RaggedRangeOpTest, AllScalarInputs) {
  BuildRaggedRangeGraph<int>();
  // With every input scalar, the result has exactly one row.
  AddInputFromArray<int>(TensorShape({}), {0});  // starts
  AddInputFromArray<int>(TensorShape({}), {5});  // limits
  AddInputFromArray<int>(TensorShape({}), {1});  // deltas
  TF_ASSERT_OK(RunOpKernel());

  // Expected: [[0, 1, 2, 3, 4]]
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput),
                                 test::AsTensor<int64>({0, 5}));
  test::ExpectTensorEqual<int>(*GetOutput(kValuesOutput),
                               test::AsTensor<int>({0, 1, 2, 3, 4}));
}
137 
TEST_F(RaggedRangeOpTest, InvalidArgsStarts) {
  BuildRaggedRangeGraph<int>();
  // A rank-2 `starts` must be rejected.
  AddInputFromArray<int>(TensorShape({4, 1}), {0, 5, 8, 5});  // starts
  AddInputFromArray<int>(TensorShape({4}), {8, 7, 8, 1});     // limits
  AddInputFromArray<int>(TensorShape({4}), {2, 1, 1, -1});    // deltas
  const Status run_status = RunOpKernel();
  EXPECT_EQ("starts must be a scalar or vector", run_status.error_message());
}
145 
TEST_F(RaggedRangeOpTest, InvalidArgsLimits) {
  BuildRaggedRangeGraph<int>();
  // A rank-2 `limits` must be rejected.
  AddInputFromArray<int>(TensorShape({4}), {0, 5, 8, 5});     // starts
  AddInputFromArray<int>(TensorShape({4, 1}), {8, 7, 8, 1});  // limits
  AddInputFromArray<int>(TensorShape({4}), {2, 1, 1, -1});    // deltas
  const Status run_status = RunOpKernel();
  EXPECT_EQ("limits must be a scalar or vector", run_status.error_message());
}
153 
TEST_F(RaggedRangeOpTest, InvalidArgsDeltas) {
  BuildRaggedRangeGraph<int>();
  // A rank-2 `deltas` must be rejected.
  AddInputFromArray<int>(TensorShape({4}), {0, 5, 8, 5});      // starts
  AddInputFromArray<int>(TensorShape({4}), {8, 7, 8, 1});      // limits
  AddInputFromArray<int>(TensorShape({4, 1}), {2, 1, 1, -1});  // deltas
  const Status run_status = RunOpKernel();
  EXPECT_EQ("deltas must be a scalar or vector", run_status.error_message());
}
161 
TEST_F(RaggedRangeOpTest, InvalidArgsShapeMismatch) {
  BuildRaggedRangeGraph<int>();
  // Vector inputs of different lengths (4 vs. 3) cannot be paired up.
  AddInputFromArray<int>(TensorShape({4}), {0, 5, 8, 5});   // starts
  AddInputFromArray<int>(TensorShape({3}), {7, 8, 1});      // limits
  AddInputFromArray<int>(TensorShape({4}), {2, 1, 1, -1});  // deltas
  const Status run_status = RunOpKernel();
  EXPECT_EQ("starts, limits, and deltas must have the same shape",
            run_status.error_message());
}
170 
TEST_F(RaggedRangeOpTest, InvalidArgsZeroDelta) {
  BuildRaggedRangeGraph<int>();
  // The third row has a zero step, which must be rejected.
  AddInputFromArray<int>(TensorShape({4}), {0, 5, 8, 5});   // starts
  AddInputFromArray<int>(TensorShape({4}), {7, 8, 8, 1});   // limits
  AddInputFromArray<int>(TensorShape({4}), {2, 1, 0, -1});  // deltas
  const Status run_status = RunOpKernel();
  EXPECT_EQ("Requires delta != 0", run_status.error_message());
}
178 
TEST_F(RaggedRangeOpTest, EmptyRangePositiveDelta) {
  BuildRaggedRangeGraph<int>();
  // With a positive step, a row whose limit is below its start is empty.
  AddInputFromArray<int>(TensorShape({2}), {0, 5});  // starts
  AddInputFromArray<int>(TensorShape({2}), {5, 0});  // limits
  AddInputFromArray<int>(TensorShape({}), {2});      // deltas (scalar)
  TF_ASSERT_OK(RunOpKernel());

  // Ragged result: [[0, 2, 4], []]
  const Tensor expected_splits = test::AsTensor<int64>({0, 3, 3});
  const Tensor expected_values = test::AsTensor<int>({0, 2, 4});
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput), expected_splits);
  test::ExpectTensorEqual<int>(*GetOutput(kValuesOutput), expected_values);
}
192 
TEST_F(RaggedRangeOpTest, EmptyRangeNegativeDelta) {
  BuildRaggedRangeGraph<int>();
  // With a negative step, a row whose limit is above its start is empty.
  AddInputFromArray<int>(TensorShape({2}), {0, 5});  // starts
  AddInputFromArray<int>(TensorShape({2}), {5, 0});  // limits
  AddInputFromArray<int>(TensorShape({}), {-2});     // deltas (scalar)
  TF_ASSERT_OK(RunOpKernel());

  // Ragged result: [[], [5, 3, 1]]
  const Tensor expected_splits = test::AsTensor<int64>({0, 0, 3});
  const Tensor expected_values = test::AsTensor<int>({5, 3, 1});
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput), expected_splits);
  test::ExpectTensorEqual<int>(*GetOutput(kValuesOutput), expected_values);
}
206 
TEST_F(RaggedRangeOpTest, ShapeFn) {
  // Shape inference for RaggedRange(starts, limits, deltas) -> [splits, values].
  ShapeInferenceTestOp op("RaggedRange");

  // Fully unknown inputs still yield two rank-1 outputs.
  INFER_OK(op, "?;?;?", "[?];[?]");

  // N rows produce N+1 splits, whichever inputs are vectors. The number of
  // values stays unknown at graph-construction time.
  INFER_OK(op, "[3];[3];[3]", "[4];[?]");
  INFER_OK(op, "[3];[3];[]", "[4];[?]");  // scalar deltas
  INFER_OK(op, "[3];[];[3]", "[4];[?]");  // scalar limits
  INFER_OK(op, "[];[3];[3]", "[4];[?]");  // scalar starts
  INFER_OK(op, "[];[];[]", "[2];[?]");    // every input scalar: one row

  // No input may have rank greater than 1.
  const char kRankTooHigh[] = "Shape must be at most rank 1 but is rank 2";
  INFER_ERROR(kRankTooHigh, op, "[5,5];[5];[5]");
  INFER_ERROR(kRankTooHigh, op, "[5];[5,5];[5]");
  INFER_ERROR(kRankTooHigh, op, "[5];[5];[5,5]");

  // Vector inputs must agree on their length.
  INFER_ERROR("Dimensions must be equal, but are 4 and 3", op, "[3];[4];[3]");
}
224 
225 }  // namespace
226 }  // namespace tensorflow
227