1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/fake_input.h"
17 #include "tensorflow/core/framework/node_def_builder.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 #include "tensorflow/core/framework/shape_inference_testutil.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_shape.h"
22 #include "tensorflow/core/framework/tensor_testutil.h"
23 #include "tensorflow/core/kernels/ops_testutil.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26
27 namespace tensorflow {
28 namespace {
29
30 class RaggedRangeOpTest : public ::tensorflow::OpsTestBase {
31 protected:
32 // Indices of output tensors.
33 static constexpr int kSplitsOutput = 0;
34 static constexpr int kValuesOutput = 1;
35
36 // Builds the tensorflow test graph for the RaggedRange op.
37 template <typename T>
BuildRaggedRangeGraph()38 void BuildRaggedRangeGraph() {
39 const auto& dtype = DataTypeToEnum<T>::v();
40 TF_ASSERT_OK(NodeDefBuilder("tested_op", "RaggedRange")
41 .Input(FakeInput(dtype)) // starts
42 .Input(FakeInput(dtype)) // limits
43 .Input(FakeInput(dtype)) // deltas
44 .Attr("T", dtype)
45 .Finalize(node_def()));
46 TF_ASSERT_OK(InitOp());
47 }
48 };
49
// Four int ranges with per-row starts, limits, and deltas, including an empty
// row (start == limit) and a row with a negative delta.
TEST_F(RaggedRangeOpTest, IntValues) {
  // Abort the test if building the kernel fails.
  ASSERT_NO_FATAL_FAILURE(BuildRaggedRangeGraph<int>());
  AddInputFromArray<int>(TensorShape({4}), {0, 5, 8, 5});    // starts
  AddInputFromArray<int>(TensorShape({4}), {8, 7, 8, 1});    // limits
  AddInputFromArray<int>(TensorShape({4}), {2, 1, 1, -1});   // deltas
  TF_ASSERT_OK(RunOpKernel());

  // Expected: [[0, 2, 4, 6], [5, 6], [], [5, 4, 3, 2]]
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput),
                                 test::AsTensor<int64>({0, 4, 6, 6, 10}));
  test::ExpectTensorEqual<int>(
      *GetOutput(kValuesOutput),
      test::AsTensor<int>({0, 2, 4, 6, 5, 6, 5, 4, 3, 2}));
}
64
// Same ranges as IntValues but with float inputs; the values output is
// compared with a tolerance while the splits are compared exactly.
TEST_F(RaggedRangeOpTest, FloatValues) {
  // Abort the test if building the kernel fails.
  ASSERT_NO_FATAL_FAILURE(BuildRaggedRangeGraph<float>());
  AddInputFromArray<float>(TensorShape({4}), {0, 5, 8, 5});    // starts
  AddInputFromArray<float>(TensorShape({4}), {8, 7, 8, 1});    // limits
  AddInputFromArray<float>(TensorShape({4}), {2, 1, 1, -1});   // deltas
  TF_ASSERT_OK(RunOpKernel());

  // Expected: [[0, 2, 4, 6], [5, 6], [], [5, 4, 3, 2]]
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput),
                                 test::AsTensor<int64>({0, 4, 6, 6, 10}));
  test::ExpectTensorNear<float>(
      *GetOutput(kValuesOutput),
      test::AsTensor<float>({0, 2, 4, 6, 5, 6, 5, 4, 3, 2}), 0.1);
}
79
// A scalar delta is broadcast against vector starts and limits.
TEST_F(RaggedRangeOpTest, BroadcastDeltas) {
  // Abort the test if building the kernel fails.
  ASSERT_NO_FATAL_FAILURE(BuildRaggedRangeGraph<int>());
  AddInputFromArray<int>(TensorShape({3}), {0, 5, 8});  // starts
  AddInputFromArray<int>(TensorShape({3}), {8, 7, 8});  // limits
  AddInputFromArray<int>(TensorShape({}), {1});         // deltas
  TF_ASSERT_OK(RunOpKernel());

  // Expected: [[0, 1, 2, 3, 4, 5, 6, 7], [5, 6], []]
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput),
                                 test::AsTensor<int64>({0, 8, 10, 10}));
  test::ExpectTensorEqual<int>(
      *GetOutput(kValuesOutput),
      test::AsTensor<int>({0, 1, 2, 3, 4, 5, 6, 7, 5, 6}));
}
94
// Scalar starts and scalar deltas are broadcast against a vector of limits.
// NOTE(review): despite its name, this test broadcasts *starts* and deltas
// (limits is the vector input); the name is kept so existing test filters
// keep matching.
TEST_F(RaggedRangeOpTest, BroadcastLimitsAndDeltas) {
  // Abort the test if building the kernel fails.
  ASSERT_NO_FATAL_FAILURE(BuildRaggedRangeGraph<int>());
  AddInputFromArray<int>(TensorShape({}), {0});         // starts
  AddInputFromArray<int>(TensorShape({3}), {3, 0, 2});  // limits
  AddInputFromArray<int>(TensorShape({}), {1});         // deltas
  TF_ASSERT_OK(RunOpKernel());

  // Expected: [[0, 1, 2], [], [0, 1]]
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput),
                                 test::AsTensor<int64>({0, 3, 3, 5}));
  test::ExpectTensorEqual<int>(*GetOutput(kValuesOutput),
                               test::AsTensor<int>({0, 1, 2, 0, 1}));
}
108
// Scalar starts and scalar limits are broadcast against a vector of deltas.
TEST_F(RaggedRangeOpTest, BroadcastStartsAndLimits) {
  // Abort the test if building the kernel fails.
  ASSERT_NO_FATAL_FAILURE(BuildRaggedRangeGraph<int>());
  AddInputFromArray<int>(TensorShape({}), {0});         // starts
  AddInputFromArray<int>(TensorShape({}), {12});        // limits
  AddInputFromArray<int>(TensorShape({3}), {3, 4, 5});  // deltas
  TF_ASSERT_OK(RunOpKernel());

  // Expected: [[0, 3, 6, 9], [0, 4, 8], [0, 5, 10]]
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput),
                                 test::AsTensor<int64>({0, 4, 7, 10}));
  test::ExpectTensorEqual<int>(
      *GetOutput(kValuesOutput),
      test::AsTensor<int>({0, 3, 6, 9, 0, 4, 8, 0, 5, 10}));
}
123
// All-scalar inputs produce a ragged tensor with exactly one row.
TEST_F(RaggedRangeOpTest, AllScalarInputs) {
  // Abort the test if building the kernel fails.
  ASSERT_NO_FATAL_FAILURE(BuildRaggedRangeGraph<int>());
  AddInputFromArray<int>(TensorShape({}), {0});  // starts
  AddInputFromArray<int>(TensorShape({}), {5});  // limits
  AddInputFromArray<int>(TensorShape({}), {1});  // deltas
  TF_ASSERT_OK(RunOpKernel());

  // Expected: [[0, 1, 2, 3, 4]]
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput),
                                 test::AsTensor<int64>({0, 5}));
  test::ExpectTensorEqual<int>(*GetOutput(kValuesOutput),
                               test::AsTensor<int>({0, 1, 2, 3, 4}));
}
137
// A rank-2 `starts` input is rejected.
TEST_F(RaggedRangeOpTest, InvalidArgsStarts) {
  // Abort the test if building the kernel fails.
  ASSERT_NO_FATAL_FAILURE(BuildRaggedRangeGraph<int>());
  AddInputFromArray<int>(TensorShape({4, 1}), {0, 5, 8, 5});  // starts
  AddInputFromArray<int>(TensorShape({4}), {8, 7, 8, 1});     // limits
  AddInputFromArray<int>(TensorShape({4}), {2, 1, 1, -1});    // deltas
  EXPECT_EQ("starts must be a scalar or vector", RunOpKernel().error_message());
}
145
// A rank-2 `limits` input is rejected.
TEST_F(RaggedRangeOpTest, InvalidArgsLimits) {
  // Abort the test if building the kernel fails.
  ASSERT_NO_FATAL_FAILURE(BuildRaggedRangeGraph<int>());
  AddInputFromArray<int>(TensorShape({4}), {0, 5, 8, 5});     // starts
  AddInputFromArray<int>(TensorShape({4, 1}), {8, 7, 8, 1});  // limits
  AddInputFromArray<int>(TensorShape({4}), {2, 1, 1, -1});    // deltas
  EXPECT_EQ("limits must be a scalar or vector", RunOpKernel().error_message());
}
153
// A rank-2 `deltas` input is rejected.
TEST_F(RaggedRangeOpTest, InvalidArgsDeltas) {
  // Abort the test if building the kernel fails.
  ASSERT_NO_FATAL_FAILURE(BuildRaggedRangeGraph<int>());
  AddInputFromArray<int>(TensorShape({4}), {0, 5, 8, 5});       // starts
  AddInputFromArray<int>(TensorShape({4}), {8, 7, 8, 1});       // limits
  AddInputFromArray<int>(TensorShape({4, 1}), {2, 1, 1, -1});   // deltas
  EXPECT_EQ("deltas must be a scalar or vector", RunOpKernel().error_message());
}
161
// Vector inputs of different lengths are rejected.
TEST_F(RaggedRangeOpTest, InvalidArgsShapeMismatch) {
  // Abort the test if building the kernel fails.
  ASSERT_NO_FATAL_FAILURE(BuildRaggedRangeGraph<int>());
  AddInputFromArray<int>(TensorShape({4}), {0, 5, 8, 5});    // starts
  AddInputFromArray<int>(TensorShape({3}), {7, 8, 1});       // limits
  AddInputFromArray<int>(TensorShape({4}), {2, 1, 1, -1});   // deltas
  EXPECT_EQ("starts, limits, and deltas must have the same shape",
            RunOpKernel().error_message());
}
170
// A zero delta in any row is rejected.
TEST_F(RaggedRangeOpTest, InvalidArgsZeroDelta) {
  // Abort the test if building the kernel fails.
  ASSERT_NO_FATAL_FAILURE(BuildRaggedRangeGraph<int>());
  AddInputFromArray<int>(TensorShape({4}), {0, 5, 8, 5});    // starts
  AddInputFromArray<int>(TensorShape({4}), {7, 8, 8, 1});    // limits
  AddInputFromArray<int>(TensorShape({4}), {2, 1, 0, -1});   // deltas
  EXPECT_EQ("Requires delta != 0", RunOpKernel().error_message());
}
178
// With a positive delta, a row whose limit is below its start is empty.
TEST_F(RaggedRangeOpTest, EmptyRangePositiveDelta) {
  // Abort the test if building the kernel fails.
  ASSERT_NO_FATAL_FAILURE(BuildRaggedRangeGraph<int>());
  AddInputFromArray<int>(TensorShape({2}), {0, 5});  // starts
  AddInputFromArray<int>(TensorShape({2}), {5, 0});  // limits
  AddInputFromArray<int>(TensorShape({}), {2});      // deltas
  TF_ASSERT_OK(RunOpKernel());

  // Expected: [[0, 2, 4], []]
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput),
                                 test::AsTensor<int64>({0, 3, 3}));
  test::ExpectTensorEqual<int>(*GetOutput(kValuesOutput),
                               test::AsTensor<int>({0, 2, 4}));
}
192
// With a negative delta, a row whose limit is above its start is empty.
TEST_F(RaggedRangeOpTest, EmptyRangeNegativeDelta) {
  // Abort the test if building the kernel fails.
  ASSERT_NO_FATAL_FAILURE(BuildRaggedRangeGraph<int>());
  AddInputFromArray<int>(TensorShape({2}), {0, 5});  // starts
  AddInputFromArray<int>(TensorShape({2}), {5, 0});  // limits
  AddInputFromArray<int>(TensorShape({}), {-2});     // deltas
  TF_ASSERT_OK(RunOpKernel());

  // Expected: [[], [5, 3, 1]]
  test::ExpectTensorEqual<int64>(*GetOutput(kSplitsOutput),
                                 test::AsTensor<int64>({0, 0, 3}));
  test::ExpectTensorEqual<int>(*GetOutput(kValuesOutput),
                               test::AsTensor<int>({5, 3, 1}));
}
206
// Shape inference for RaggedRange(starts, limits, deltas) -> [splits, values].
// The splits output has one more entry than the number of ranges, and the
// values output length is never inferred (it is reported as [?]).
TEST_F(RaggedRangeOpTest, ShapeFn) {
  ShapeInferenceTestOp op("RaggedRange");

  // Nothing is known about any input.
  INFER_OK(op, "?;?;?", "[?];[?]");

  // Degenerate case: every input is a scalar, describing a single range.
  INFER_OK(op, "[];[];[]", "[2];[?]");

  // Vector inputs, optionally with one scalar broadcast against the others.
  INFER_OK(op, "[3];[3];[3]", "[4];[?]");
  INFER_OK(op, "[];[3];[3]", "[4];[?]");  // starts is broadcast
  INFER_OK(op, "[3];[];[3]", "[4];[?]");  // limits is broadcast
  INFER_OK(op, "[3];[3];[]", "[4];[?]");  // deltas is broadcast

  // Each input must be a scalar or a vector.
  const char kRankError[] = "Shape must be at most rank 1 but is rank 2";
  INFER_ERROR(kRankError, op, "[5,5];[5];[5]");
  INFER_ERROR(kRankError, op, "[5];[5,5];[5]");
  INFER_ERROR(kRankError, op, "[5];[5];[5,5]");

  // Vector inputs must all have the same length.
  INFER_ERROR("Dimensions must be equal, but are 4 and 3", op, "[3];[4];[3]");
}
224
225 } // namespace
226 } // namespace tensorflow
227