/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/core/framework/node_def_builder.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference_testutil.h"
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/framework/tensor_shape.pb.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/lib/core/status_test_util.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/platform/test.h"
25
26 namespace tensorflow {
27
// Verifies AddN shape inference: the N input shapes are merged into one
// output shape, and rank or dimension conflicts are reported together with
// the index of the input whose merge failed.
TEST(MathOpsTest, AddN_ShapeFn) {
  ShapeInferenceTestOp add_n("AddN");

  // Rebuilds the NodeDef with `num_inputs` float inputs and a matching "N".
  auto configure_inputs = [&add_n](int num_inputs) {
    std::vector<NodeDefBuilder::NodeOut> inputs(
        num_inputs, NodeDefBuilder::NodeOut("a", 0, DT_FLOAT));
    TF_ASSERT_OK(NodeDefBuilder("test", "AddN")
                     .Input(inputs)
                     .Attr("N", num_inputs)
                     .Finalize(&add_n.node_def));
  };

  configure_inputs(2);
  // Two unknown shapes: either input may be forwarded.
  INFER_OK(add_n, "?;?", "in0|in1");

  // A known shape is preferred over an unknown one.
  INFER_OK(add_n, "[1];[?]", "in0");
  INFER_OK(add_n, "[1];?", "in0");
  INFER_OK(add_n, "[?];[1]", "in1");
  INFER_OK(add_n, "?;[1]", "in1");

  // The most fully-defined compatible input is forwarded.
  INFER_OK(add_n, "[1,2];[?,2]", "in0");
  INFER_OK(add_n, "[1,2];[1,2]", "in0|in1");
  INFER_OK(add_n, "[?,2];[1,2]", "in1");

  configure_inputs(3);
  INFER_OK(add_n, "[1,?];[?,2];[1,2]", "in2");
  INFER_OK(add_n, "[1,2];[?,2];[1,?]", "in0");
  INFER_OK(add_n, "?;?;[1,2]", "in2");

  configure_inputs(2);
  INFER_OK(add_n, "?;[1,2]", "in1");
  // When no single input is complete, the output mixes dimensions taken from
  // several inputs.
  INFER_OK(add_n, "[1,?];[?,2]", "[d0_0,d1_1]");
  INFER_OK(add_n, "[?,2,?];[?,?,3]", "[d0_0|d1_0,d0_1,d1_2]");
  INFER_OK(add_n, "[?,2];[1,?]", "[d1_0,d0_1]");

  // Conflicts name the input whose merge failed.
  configure_inputs(3);
  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 2 and 4",
              add_n, "[1,2];?;[1,4]");
  INFER_ERROR("From merging shape 0 with other shapes.", add_n,
              "[1,2];?;[1,4]");
  configure_inputs(4);
  INFER_ERROR("Shapes must be equal rank, but are 2 and 3", add_n,
              "?;[1,2];?;[1,2,3]");
  INFER_ERROR("From merging shape 1 with other shapes.", add_n,
              "?;[1,2];?;[1,2,3]");
}
76
// Cast forwards its input shape untouched, for unknown-rank, vector and
// rank-4 inputs alike.
TEST(MathOpsTest, UnchangedShape_ShapeFn) {
  ShapeInferenceTestOp cast_op("Cast");
  for (const char* input_shape : {"?", "[?]", "[1,?,3,4]"}) {
    INFER_OK(cast_op, input_shape, "in0");
  }
}
83
// Exercises the shape function shared by the sorted segment reductions
// (SegmentReductionShapeFn): the output keeps every data dimension except the
// first, which becomes unknown.
TEST(MathOpsTest, Segment_ShapeFn) {
  const char* const kSegmentOps[] = {"SegmentMax", "SegmentMean",
                                     "SegmentMin", "SegmentProd",
                                     "SegmentSum"};
  for (const char* name : kSegmentOps) {
    ShapeInferenceTestOp seg_op(name);

    // Unknown data rank: nothing can be inferred.
    INFER_OK(seg_op, "?;?", "?");
    INFER_OK(seg_op, "?;[100]", "?");

    // Rank-1 data: the result is a vector of unknown length.
    INFER_OK(seg_op, "[?];?", "[?]");
    INFER_OK(seg_op, "[?];[100]", "[?]");
    INFER_OK(seg_op, "[1];?", "[?]");
    INFER_OK(seg_op, "[1];[100]", "[?]");

    // Higher-rank data: dimensions 1.. are copied from the first input.
    INFER_OK(seg_op, "[?,?];?", "[?,d0_1]");
    INFER_OK(seg_op, "[?,2];[100]", "[?,d0_1]");
    INFER_OK(seg_op, "[?,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]");
    INFER_OK(seg_op, "[1,?];?", "[?,d0_1]");
    INFER_OK(seg_op, "[1,2];[100]", "[?,d0_1]");
    INFER_OK(seg_op, "[1,2,?,4];[100]", "[?,d0_1,d0_2,d0_3]");

    // The second input must be a vector; the first must be at least rank 1.
    INFER_ERROR("Shape must be rank 1 but is rank 2", seg_op, "?;[1,2]");
    INFER_ERROR("Shape must be at least rank 1 but is rank 0", seg_op,
                "[];[1]");
  }
}
111
// Checks the broadcasting shape function shared by the element-wise binary ops
// listed below. A size-1 (or missing, for lower-rank inputs) dimension takes
// the other input's dimension, and inputs of different rank are aligned on
// their trailing dimensions.
TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
  for (const auto* op_name : {"Add",     "Complex",
                              "Div",     "Equal",
                              "Greater", "GreaterEqual",
                              "Igamma",  "Igammac",
                              "Zeta",    "Polygamma",
                              "Less",    "LessEqual",
                              "LogicalAnd", "LogicalOr",
                              "Maximum", "Minimum",
                              "Mod",     "Mul",
                              "NotEqual", "Pow",
                              "Sub",     "SquaredDifference",
                              "DivNoNan"}) {
    ShapeInferenceTestOp op(op_name);

    // Unknown rank on either side gives an unknown output.
    INFER_OK(op, "?;?", "?");
    INFER_OK(op, "[1,2];?", "?");
    INFER_OK(op, "?;[1,2]", "?");

    // Unknown dimensions.
    INFER_OK(op, "[?];[1]", "[d0_0]");
    INFER_OK(op, "[1];[?]", "[d1_0]");
    INFER_OK(op, "[?];[2]", "[d1_0]");
    INFER_OK(op, "[2];[?]", "[d0_0]");
    INFER_OK(op, "[?];[?]", "[?]");
    INFER_OK(op, "[];[?]", "[d1_0]");
    INFER_OK(op, "[?];[]", "[d0_0]");

    // Size-1 dimensions.
    INFER_OK(op, "[1];[1]", "[d0_0|d1_0]");
    INFER_OK(op, "[];[1]", "[d1_0]");
    INFER_OK(op, "[1];[]", "[d0_0]");

    // Size-2 dimensions. ("[2];[?]" is already covered above; the duplicate
    // assertion that used to follow "[2];[]" has been dropped.)
    INFER_OK(op, "[2];[2]", "[d0_0|d1_0]");
    INFER_OK(op, "[];[2]", "[d1_0]");
    INFER_OK(op, "[1];[2]", "[d1_0]");
    INFER_OK(op, "[2];[1]", "[d0_0]");
    INFER_OK(op, "[2];[]", "[d0_0]");

    // Size-0 dimensions also broadcast against size 1.
    INFER_OK(op, "[0];[0]", "[d0_0|d1_0]");
    INFER_OK(op, "[];[0]", "[d1_0]");
    INFER_OK(op, "[1];[0]", "[d1_0]");
    INFER_OK(op, "[0];[1]", "[d0_0]");
    INFER_OK(op, "[0];[]", "[d0_0]");

    // Different ranks: inputs are aligned on their trailing dimensions.
    INFER_OK(op, "[2];[?,?]", "[d1_0,d0_0]");
    INFER_OK(op, "[2,2];[?,?,?]", "[d1_0,d0_0,d0_1]");

    // Multiple dimension cases (same test cases, switching x and y).
    INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]",
             "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]");
    INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]",
             "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]");
  }
}
165
// Checks Select shape inference.
//
// The output is the merge of the two data inputs (inputs 1 and 2). The
// condition (input 0) may be:
//   * a scalar,
//   * a vector whose length matches the first data dimension, or
//   * a tensor of the same rank as the data, whose shape is merged into the
//     output.
// Handle shape/dtype data attached to the data inputs is merged as well.
// Mismatched shapes, dtypes, or tensor counts in that data are reported as
// errors.
TEST(MathOpsTest, Select_ShapeFn) {
  ShapeInferenceTestOp op("Select");
  INFER_OK(op, "?;?;?", "in1|in2");

  // scalar case
  INFER_OK(op, "[];[1];?", "in1");
  INFER_OK(op, "[];?;?", "in1|in2");

  INFER_OK(op, "[1];?;?",
           "in1|in2");  // When cond is vector, t/e may not match it.
  // A rank-2 condition gives no information about unknown-rank data either.
  INFER_OK(op, "[1,2];?;?", "in1|in2");

  INFER_OK(op, "?;[];?", "in1");
  INFER_OK(op, "?;?;[]", "in2");
  INFER_OK(op, "?;[1];?", "in1");
  INFER_OK(op, "?;?;[1]", "in2");
  INFER_OK(op, "?;[1,2];?", "in1");
  INFER_OK(op, "?;?;[1,2]", "in2");

  INFER_ERROR("Shapes must be equal rank, but are 0 and 1", op, "[1];[];?");
  INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[];[1];[1,2]");
  INFER_ERROR("Shapes must be equal rank, but are 1 and 2", op, "[1,2];[1];?");
  INFER_OK(op, "[2];[?];[?]", "in1|in2");

  INFER_OK(op, "[?];[?,?,3];[1,2,?]", "[d2_0,d2_1,d1_2]");
  INFER_OK(op, "[2];[?,?,3];[?,2,?]", "[d1_0|d2_0,d2_1,d1_2]");
  INFER_ERROR("must be equal", op, "[1];[2,?,3];[?,2,?]");
  INFER_ERROR("Shapes must be equal rank, but are 3 and 2", op,
              "[2,?];[?,?,3];[?,2,?]");
  INFER_OK(op, "[2,?,?];[?,?,3];[?,2,?]", "[d0_0,d2_1,d1_2]");
  INFER_ERROR("Dimension 2 in both shapes must be equal, but are 3 and 5", op,
              "[2,?,5];[?,?,3];[?,2,?]");

  // Test that handles were merged.
  //
  // The checks below modify handle_data and call run_inference_for_handles to
  // rerun shape inference, which replaces the context held in `c`.
  const OpRegistrationData* op_reg_data;
  TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));
  using ShapeDtypeV = std::vector<std::pair<TensorShapeProto, DataType>>;
  std::vector<std::unique_ptr<ShapeDtypeV>> handle_data;
  std::unique_ptr<shape_inference::InferenceContext> c;
  auto run_inference_for_handles = [&]() -> Status {
    CHECK(op_reg_data->shape_inference_fn != nullptr);
    c.reset(new shape_inference::InferenceContext(
        TF_GRAPH_DEF_VERSION, &op.node_def, op_reg_data->op_def,
        {TensorShapeProto(), TensorShapeProto(), TensorShapeProto()}, {}, {},
        handle_data));
    TF_CHECK_OK(c->construction_status());
    Status s = c->Run(op_reg_data->shape_inference_fn);
    LOG(INFO) << "Inference got " << s;
    return s;
  };
  // Reruns inference and expects a failure whose message contains
  // `expected_substring`. The actual status is printed if the check fails.
  auto expect_inference_error = [&](const char* expected_substring) {
    const Status s = run_inference_for_handles();
    EXPECT_TRUE(str_util::StrContains(s.error_message(), expected_substring))
        << s;
  };
  auto shape_proto = [](std::initializer_list<int64> dim_sizes) {
    TensorShapeProto p;
    for (auto i : dim_sizes) p.add_dim()->set_size(i);
    return p;
  };
  // Appends one input's handle data. The unique_ptr owns the allocation before
  // the vector can grow, so a throwing reallocation cannot leak it.
  auto add_handle_data = [&handle_data](ShapeDtypeV shapes_and_dtypes) {
    handle_data.push_back(std::unique_ptr<ShapeDtypeV>(
        new ShapeDtypeV(std::move(shapes_and_dtypes))));
  };

  TensorShapeProto i0 = shape_proto({1, -1});
  TensorShapeProto i1 = shape_proto({-1, 2});
  TensorShapeProto unknown_shape;
  unknown_shape.set_unknown_rank(true);
  TensorShapeProto scalar;

  add_handle_data({{scalar, DT_FLOAT}, {unknown_shape, DT_INT32}});
  add_handle_data({{i0, DT_FLOAT}, {i1, DT_INT32}});
  add_handle_data({{i1, DT_FLOAT}, {unknown_shape, DT_INT32}});

  TF_ASSERT_OK(run_inference_for_handles());
  auto* out = c->output_handle_shapes_and_types(0);
  ASSERT_EQ(2, out->size());
  EXPECT_EQ("[1,2]", c->DebugString(out->at(0).shape));
  EXPECT_EQ(DT_FLOAT, out->at(0).dtype);
  EXPECT_EQ("[?,2]", c->DebugString(out->at(1).shape));
  EXPECT_EQ(DT_INT32, out->at(1).dtype);

  // Expect an error when the shapes can't be merged.
  handle_data[2]->at(0).first = shape_proto({2, 2});
  expect_inference_error("must be equal, but are 1 and 2");
  handle_data[2]->at(0).first = i1;  // restore to valid

  // Expect an error when the types can't be merged.
  handle_data[2]->at(1).second = DT_INT64;
  expect_inference_error("pointing to different dtypes");
  handle_data[2]->at(1).second = DT_INT32;  // restore to valid

  // Expect an error when different numbers of tensors are merged.
  handle_data[2]->push_back({i1, DT_FLOAT});
  expect_inference_error("pointing to different numbers of tensors");
  handle_data[2]->pop_back();  // restore to valid.
}
264
// Checks Range shape inference. Every input must be a scalar. Once all three
// values are constant, the output length is computed and invalid combinations
// are rejected: a zero delta, or a delta pointing away from the limit.
TEST(MathOpsTest, Range_ShapeFn) {
  ShapeInferenceTestOp range_op("Range");

  TF_ASSERT_OK(NodeDefBuilder("test", "Range")
                   .Input({"start", {}, DT_INT32})
                   .Input({"limit", {}, DT_INT32})
                   .Input({"delta", {}, DT_INT32})
                   .Attr("Tidx", DT_INT32)
                   .Finalize(&range_op.node_def));
  range_op.input_tensors.resize(3);

  // With no constant inputs, only the output rank is known.
  INFER_OK(range_op, "?;?;?", "[?]");

  // Each input must be rank 0, and the error names the offending input.
  INFER_ERROR("Shape must be rank 0 but is rank 2", range_op, "[1,2];?;?");
  INFER_ERROR("for 'start'", range_op, "[1,2];?;?");
  INFER_ERROR("Shape must be rank 0 but is rank 2", range_op, "?;[1,2];?");
  INFER_ERROR("for 'limit'", range_op, "?;[1,2];?");
  INFER_ERROR("Shape must be rank 0 but is rank 2", range_op, "?;?;[1,2]");
  INFER_ERROR("for 'delta'", range_op, "?;?;[1,2]");

  // The length stays unknown until all three inputs are constant.
  Tensor start_val = test::AsScalar(1);
  range_op.input_tensors[0] = &start_val;
  INFER_OK(range_op, "?;?;?", "[?]");
  Tensor limit_val = test::AsScalar(1);
  range_op.input_tensors[1] = &limit_val;
  INFER_OK(range_op, "?;?;?", "[?]");
  Tensor delta_val = test::AsScalar(1);
  range_op.input_tensors[2] = &delta_val;
  INFER_OK(range_op, "?;?;?", "[0]");  // start == limit: empty.

  delta_val = test::AsScalar(0);
  INFER_ERROR("Requires delta != 0", range_op, "?;?;?");
  delta_val = test::AsScalar(3);

  limit_val = test::AsScalar(-1);
  INFER_ERROR("Requires start <= limit when delta > 0: 1/-1", range_op,
              "?;?;?");

  delta_val = test::AsScalar(-1);
  INFER_OK(range_op, "?;?;?", "[2]");  // {1, 0}

  limit_val = test::AsScalar(4);
  INFER_ERROR("Requires start >= limit when delta < 0: 1/4", range_op,
              "?;?;?");

  limit_val = test::AsScalar(100);
  start_val = test::AsScalar(2);
  delta_val = test::AsScalar(3);
  INFER_OK(range_op, "?;?;?", "[33]");  // ceil((100 - 2) / 3)
}
315
// Checks LinSpace shape inference: start, stop and num must be scalars, and a
// constant num fixes the output length.
TEST(MathOpsTest, LinSpace_ShapeFn) {
  ShapeInferenceTestOp linspace("LinSpace");
  linspace.input_tensors.resize(3);
  INFER_OK(linspace, "?;?;?", "[?]");

  // A rank-2 value for any input is rejected, and the error names the input.
  const struct {
    const char* shapes;
    const char* names_input;
  } kNonScalarInputs[] = {
      {"[1,2];?;?", "for 'start'"},
      {"?;[1,2];?", "for 'stop'"},
      {"?;?;[1,2]", "for 'num'"},
  };
  for (const auto& bad : kNonScalarInputs) {
    INFER_ERROR("Shape must be rank 0 but is rank 2", linspace, bad.shapes);
    INFER_ERROR(bad.names_input, linspace, bad.shapes);
  }

  // With num constant, the output has exactly num elements.
  Tensor num = test::AsScalar(1);
  linspace.input_tensors[2] = &num;
  INFER_OK(linspace, "?;?;?", "[1]");
  num = test::AsScalar(2);
  INFER_OK(linspace, "?;?;?", "[2]");

  // A negative num is rejected.
  num = test::AsScalar(-1);
  INFER_ERROR("Requires num > 0: -1", linspace, "?;?;?");
}
335
// Checks UnsortedSegmentSum shape inference. The leading dimensions of the
// first input must match the second input's shape, and the third input must be
// a scalar. A constant third input becomes the output's leading dimension.
TEST(MathOpsTest, UnsortedSegmentSum_ShapeFn) {
  ShapeInferenceTestOp unsorted_sum("UnsortedSegmentSum");
  unsorted_sum.input_tensors.resize(3);

  // Too little is known to infer anything.
  INFER_OK(unsorted_sum, "?;?;?", "?");
  INFER_OK(unsorted_sum, "?;[?];?", "?");
  INFER_OK(unsorted_sum, "?;[3];?", "?");

  // Invalid shapes.
  INFER_ERROR("Shape must be rank 0 but is rank 2", unsorted_sum,
              "?;?;[1,2]");
  INFER_ERROR("Dimensions must be equal, but are 2 and 3", unsorted_sum,
              "[1,?,2];[1,?,3];?");
  INFER_ERROR("Shape must be at least rank 3 but is rank 2", unsorted_sum,
              "[1,2];[1,2,3];?");

  // Output: the constant third input, followed by the first input's
  // dimensions beyond the second input's rank.
  Tensor num_segments = test::AsScalar(100);
  unsorted_sum.input_tensors[2] = &num_segments;
  INFER_OK(unsorted_sum, "[?,2,3,?,5];[1,2,?];[]", "[100,d0_3,d0_4]");

  num_segments = test::AsScalar(-1);
  INFER_ERROR("Dimension size, given by scalar input 2, must be "
              "non-negative but is -1",
              unsorted_sum, "[3];[3];?");
}
357
// Checks SparseSegmentSum shape inference. The output keeps the first input's
// inner dimensions, with an unknown leading dimension. The second and third
// inputs must be vectors of equal length.
TEST(MathOpsTest, SparseSegment_ShapeFn) {
  ShapeInferenceTestOp sparse_sum("SparseSegmentSum");
  sparse_sum.input_tensors.resize(3);

  INFER_OK(sparse_sum, "?;?;?", "?");
  INFER_OK(sparse_sum, "[2,4,3];[3];[3]", "[?,d0_1,d0_2]");

  // Inputs 1 and 2 must both be rank 1 ...
  INFER_ERROR("Shape must be rank 1 but is rank 0", sparse_sum,
              "[2,4,3];[];[3]");
  INFER_ERROR("Shape must be rank 1 but is rank 2", sparse_sum,
              "[2,4,3];[3];[3,4]");

  // ... and of the same length.
  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 3 and 4",
              sparse_sum, "[2,4,3];[3];[4]");
}
370
// Checks SparseSegmentMeanGrad shape inference. A constant fourth input
// supplies the leading output dimension; otherwise that dimension is unknown.
// The fourth input must be a non-negative scalar.
TEST(MathOpsTest, SparseSegmentGrad_ShapeFn) {
  ShapeInferenceTestOp mean_grad("SparseSegmentMeanGrad");
  mean_grad.input_tensors.resize(4);

  INFER_OK(mean_grad, "?;?;?;?", "?");
  INFER_OK(mean_grad, "[2,4,3];[3];[3];[]", "[?,d0_1,d0_2]");

  Tensor leading_dim = test::AsScalar(100);
  mean_grad.input_tensors[3] = &leading_dim;
  INFER_OK(mean_grad, "[2,4,3];[3];[3];[]", "[100,d0_1,d0_2]");

  INFER_ERROR("Shape must be rank 0 but is rank 2", mean_grad,
              "[2,4,3];[3];[3];[1,1]");

  // Negative value is not allowed. input_tensors[3] still points at
  // leading_dim, so updating the tensor is enough.
  leading_dim = test::AsScalar(-100);
  INFER_ERROR("Cannot specify a negative value", mean_grad,
              "[2,4,3];[3];[3];[]");
}
389
// Checks BatchMatMul shape inference under each combination of the adj_x and
// adj_y attributes. The leading (batch) dimensions come from the first
// operand, and the contracted (inner) dimensions must agree.
TEST(MathOpsTest, BatchMatMul_ShapeFn) {
  ShapeInferenceTestOp matmul("BatchMatMul");
  auto configure_adjoints = [&matmul](bool adj_x, bool adj_y) {
    TF_ASSERT_OK(NodeDefBuilder("test", "BatchMatMul")
                     .Input({"a", 0, DT_FLOAT})
                     .Input({"b", 0, DT_FLOAT})
                     .Attr("adj_x", adj_x)
                     .Attr("adj_y", adj_y)
                     .Finalize(&matmul.node_def));
  };

  configure_adjoints(false, false);

  // Both operands need at least rank 2.
  INFER_ERROR("at least rank 2", matmul, "[1];?");
  INFER_ERROR("at least rank 2", matmul, "?;[2]");

  INFER_OK(matmul, "?;?", "?");

  // No batch dimensions.
  INFER_OK(matmul, "[?,?];[?,?]", "[d0_0,d1_1]");

  // Two batch dimensions. The column count is unknown because the second
  // operand's shape is unknown.
  INFER_OK(matmul, "[?,?,?,?];?", "[d0_0,d0_1,d0_2,?]");

  // adj_x: with it set, rows come from the first operand's last dimension.
  // The inner dimensions are compared in both settings.
  INFER_OK(matmul, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_2,d1_3]");
  INFER_ERROR("are 2 and 3", matmul, "[?,1,2];[?,3,1]");  // inner dim mismatch
  configure_adjoints(true, false);
  INFER_OK(matmul, "[1,2,3,4];[1,2,?,?]", "[d0_0,d0_1,d0_3,d1_3]");
  INFER_ERROR("are 2 and 3", matmul, "[?,2,1];[?,3,1]");  // inner dim mismatch

  // adj_y: columns come from the second operand's second-to-last dimension.
  configure_adjoints(false, true);
  INFER_OK(matmul, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_2,d1_2]");
  INFER_ERROR("are 2 and 3", matmul, "[?,1,2];[?,1,3]");  // inner dim mismatch
  configure_adjoints(true, true);
  INFER_OK(matmul, "[1,2,?,?];[1,2,3,4]", "[d0_0,d0_1,d0_3,d1_2]");
  INFER_ERROR("are 2 and 3", matmul, "[?,2,1];[?,1,3]");  // inner dim mismatch
}
431
// Checks ArgMax shape inference. A constant dimension argument removes that
// axis from the input shape; negative values count from the end. Without a
// constant, only the output rank (input rank minus one) is known.
TEST(MathOpsTest, ArgOps_ShapeFn) {
  ShapeInferenceTestOp argmax("ArgMax");
  argmax.input_tensors.resize(2);

  INFER_OK(argmax, "?;?", "?");

  // Inputs of rank 0 or 1 reduce to a scalar.
  INFER_OK(argmax, "[2];?", "[]");
  INFER_OK(argmax, "[];?", "[]");

  // The dimension argument must itself be a scalar.
  INFER_ERROR("must be rank 0", argmax, "[2];[1]");

  // Unknown dimension, known input rank: all output dimensions are unknown.
  INFER_OK(argmax, "[2,3,4];?", "[?,?]");
  INFER_OK(argmax, "[2,3,4,5,6];?", "[?,?,?,?]");

  // input_tensors[1] keeps pointing at `axis`; only its value changes below.
  Tensor axis;
  argmax.input_tensors[1] = &axis;

  const struct {
    int value;
    const char* expected;
  } kValidAxes[] = {
      {0, "[d0_1,d0_2]"},
      {1, "[d0_0,d0_2]"},
      {2, "[d0_0,d0_1]"},
      {-1, "[d0_0,d0_1]"},
  };
  for (const auto& tc : kValidAxes) {
    axis = test::AsScalar(tc.value);
    INFER_OK(argmax, "[2,3,4];[]", tc.expected);
  }

  // Out-of-range dimension values are rejected.
  for (int out_of_range : {10, -10}) {
    axis = test::AsScalar(out_of_range);
    INFER_ERROR("must be in the range [-3, 3)", argmax, "[2,3,4];[]");
  }
}
476
// Checks Betainc shape inference: the non-scalar inputs are merged, and scalar
// inputs do not constrain the output.
TEST(MathOpsTest, Betainc_ShapeFn) {
  ShapeInferenceTestOp betainc("Betainc");

  INFER_OK(betainc, "?;?;?", "?");
  INFER_OK(betainc, "[?,?];?;?", "in0");

  // Known non-scalar inputs are merged dimension by dimension.
  INFER_OK(betainc, "[?,2];?;[1,?]", "[d2_0,d0_1]");
  INFER_OK(betainc, "[?,2,?];[1,?,?];[?,?,3]", "[d1_0,d0_1,d2_2]");

  // Scalar inputs are skipped during the merge.
  INFER_OK(betainc, "[?,2,?];[];[?,?,3]", "[d0_0|d2_0,d0_1,d2_2]");
  INFER_OK(betainc, "[];[];[?,?,3]", "in2");

  // With only one non-scalar (or unknown) input, that input is forwarded.
  INFER_OK(betainc, "[];[];?", "in2");
  INFER_OK(betainc, "[];[];[1,2,3,4]", "in2");

  // With every input scalar, the first one is forwarded.
  INFER_OK(betainc, "[];[];[]", "in0");

  // Non-scalar inputs must agree in shape.
  INFER_ERROR("must be equal", betainc, "[1,2];[];[1,4]");
  INFER_ERROR("must be equal", betainc, "[1,2];[];[1,2,3]");
}
499
// Checks Requantize shape inference. The first output mirrors input 0, the two
// remaining outputs are scalars, and inputs 1-4 must be scalars.
TEST(MathOpsTest, Requantize_ShapeFn) {
  ShapeInferenceTestOp requantize("Requantize");

  for (const char* inputs : {"?;?;?;?;?", "?;[];[];[];[]"}) {
    INFER_OK(requantize, inputs, "in0;[];[]");
  }

  // A rank-1 value in any of inputs 1-4 is rejected.
  for (const char* inputs :
       {"?;[1];?;?;?", "?;?;[2];?;?", "?;?;?;[3];?", "?;?;?;?;[4]"}) {
    INFER_ERROR("must be rank 0", requantize, inputs);
  }
}
512
// Checks RequantizationRange shape inference: both outputs are scalars, and
// inputs 1 and 2 must be scalars.
TEST(MathOpsTest, RequantizationRange_ShapeFn) {
  ShapeInferenceTestOp op("RequantizationRange");

  INFER_OK(op, "?;?;?", "[];[]");
  INFER_OK(op, "?;[];[]", "[];[]");

  // Rank checks on input scalars.
  INFER_ERROR("must be rank 0", op, "?;[1];?");
  INFER_ERROR("must be rank 0", op, "?;?;[2]");
}
523
// Checks Cross shape inference. Both inputs must be at least rank 1, have equal
// shapes, and (when known) a last dimension of 3. The output takes the first
// input's shape.
TEST(MathOpsTest, Cross_ShapeFn) {
  ShapeInferenceTestOp cross("Cross");

  // Rejected shapes.
  INFER_ERROR("Shape must be at least rank 1 but is rank 0", cross, "[];[]");
  INFER_ERROR("Dimension 0 in both shapes must be equal, but", cross,
              "[3];[5]");
  INFER_ERROR("Dimension must be 3 but", cross, "[3,5];[3,5]");

  // Accepted shapes: the first input is forwarded.
  for (const char* inputs : {"?;?", "[?];[?]", "[1,?,3];[?,?,?]"}) {
    INFER_OK(cross, inputs, "in0");
  }
}
535
// Checks HistogramFixedWidth shape inference. value_range must be a
// two-element vector and nbins a scalar. Without constant inputs, the output
// is a vector of unknown length.
TEST(MathOpsTest, HistogramFixedWidth_ShapeFn) {
  ShapeInferenceTestOp histogram("HistogramFixedWidth");

  // value_range (second input) must be a vector ...
  INFER_ERROR("Shape must be rank 1 but is rank 0", histogram, "[];[];[]");
  // ... holding exactly two elements.
  INFER_ERROR("Dimension must be 2 but is 3", histogram, "[];[3];[]");
  // nbins (third input) must be a scalar.
  INFER_ERROR("Shape must be rank 0 but is rank 1", histogram, "[];[2];[2]");

  for (const char* inputs : {"?;?;?", "[?];[2];[]", "[?];[2];?"}) {
    INFER_OK(histogram, inputs, "[?]");
  }
}
550
// Checks QuantizedAdd shape inference. The first output is derived from the
// first two inputs, the other two outputs are scalars, and inputs 2-5 must be
// scalars.
TEST(MathOpsTest, QuantizedAdd_ShapeFn) {
  ShapeInferenceTestOp quantized_add("QuantizedAdd");

  // Unknown or partially known inputs leave the first output unknown.
  INFER_OK(quantized_add, "?;?;?;?;?;?", "?;[];[]");
  INFER_OK(quantized_add, "?;?;[];[];[];[]", "?;[];[]");
  INFER_OK(quantized_add, "[1,2];?;[];[];[];[]", "?;[];[]");
  // A scalar combined with a [2] vector gives the vector's shape.
  INFER_OK(quantized_add, "[];[2];[];[];[];[]", "[d1_0];[];[]");

  // A rank-1 value in any of inputs 2-5 is rejected.
  INFER_ERROR("must be rank 0", quantized_add, "?;?;[1];?;?;?");
  INFER_ERROR("must be rank 0", quantized_add, "?;?;?;[2];?;?");
  INFER_ERROR("must be rank 0", quantized_add, "?;?;?;?;[3];?");
  INFER_ERROR("must be rank 0", quantized_add, "?;?;?;?;?;[4]");
}
565
// Checks Bincount shape inference. `size` (the second input) must be a scalar.
// Without constant inputs, the output is a vector of unknown length.
TEST(MathOpsTest, Bincount_ShapeFn) {
  ShapeInferenceTestOp bincount("Bincount");

  INFER_ERROR("Shape must be rank 0 but is rank 1", bincount, "?;[1];?");

  for (const char* inputs : {"?;?;?", "?;[];?", "[?];[];?", "[?];[];[?]"}) {
    INFER_OK(bincount, inputs, "[?]");
  }
}
577 } // end namespace tensorflow
578