1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/node_def_builder.h"
17 #include "tensorflow/core/framework/node_def_util.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20 #include "tensorflow/core/framework/shape_inference_testutil.h"
21 #include "tensorflow/core/framework/tensor.pb.h"
22 #include "tensorflow/core/framework/tensor_shape.pb.h"
23 #include "tensorflow/core/framework/tensor_testutil.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/public/version.h"
27
28 namespace tensorflow {
29
// Verifies the TensorScatterUpdate shape function. Input specs are written in
// the order (input, indices, updates); successful inference forwards input 0.
TEST(ArrayOpsTest, TensorScatterUpdate_ShapeFn) {
  ShapeInferenceTestOp op("TensorScatterUpdate");

  // Accepted combinations: the output is input 0 in every case.
  for (const char* ins :
       {"[4,3];[8,2];[8]", "[?,?];[?,2];[?]", "[?];[?];[?]"}) {
    INFER_OK(op, ins, "in0");
  }

  // A scalar input is rejected.
  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op,
              "[];[?,2];[?]");
  // An input with a zero-sized dimension is rejected.
  INFER_ERROR("Indices and updates specified for empty input", op,
              "[0,2,2];[8,2];[8]");

  // Leading dimensions of indices and updates disagree (8 vs. 9).
  const char* const kOuterDimsMismatch =
      "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match "
      "dimensions [0,1) of updates[shape=[9]] = [9]";
  INFER_ERROR(kOuterDimsMismatch, op, "[?,?];[8,2];[9]");

  // Trailing dimensions of updates do not line up with the input.
  const char* const kInnerDimsMismatch =
      "Dimensions [2,2) of input[shape=[?,?]] = [] must match "
      "dimensions [1,2) of updates[shape=[?,1]] = [1]";
  INFER_ERROR(kInnerDimsMismatch, op, "[?,?];[?,2];[?,1]");
}
50
// Verifies the ScatterNd shape function: one successful inference plus the
// rank and dimension-mismatch errors.
TEST(ArrayOpsTest, ScatterNd_ShapeFn) {
  ShapeInferenceTestOp op("ScatterNd");

  // With a two-element third input the result is a rank-2 unknown shape.
  const char* const kValidInputs = "[8,2];[8];[2]";
  INFER_OK(op, kValidInputs, "[?,?]");

  // The third input must be a vector, not a scalar.
  const char* const kScalarThirdInput = "[?,2];[?];[]";
  INFER_ERROR("Shape must be rank 1 but is rank 0", op, kScalarThirdInput);

  // Leading dimensions of indices and updates disagree (8 vs. 9).
  const char* const kOuterDimsMismatch =
      "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match "
      "dimensions [0,1) of updates[shape=[9]] = [9]";
  INFER_ERROR(kOuterDimsMismatch, op, "[8,2];[9];[?]");
}
62
// Verifies the UnravelIndex shape function using a table of
// (input spec, expected output spec) pairs, followed by one rank error.
TEST(ArrayOpsTest, UnravelIndex_ShapeFn) {
  ShapeInferenceTestOp op("UnravelIndex");

  struct OkCase {
    const char* ins;
    const char* outs;
  };
  const OkCase kOkCases[] = {
      {"?;?", "?"},
      {"[];[?]", "[d1_0]"},
      {"[4,5];[?]", "[d1_0,20]"},
      {"[2,3,4];[?]", "[d1_0,24]"},
      {"?;[?]", "?"},
      {"[?];[?]", "[d1_0,?]"},
  };
  for (const OkCase& ok_case : kOkCases) {
    INFER_OK(op, ok_case.ins, ok_case.outs);
  }

  // The second input must be a vector.
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,1]");
}
77
// Verifies the Pack shape function for every valid axis (positive and
// negative spelling), for out-of-range axes, and for rank-mismatched inputs.
TEST(ArrayOpsTest, Pack_ShapeFn) {
  ShapeInferenceTestOp op("Pack");

  // Rebuilds op.node_def as a three-input Pack node with the given axis attr.
  auto set_axis = [&op](int axis) {
    const int num_inputs = 3;
    std::vector<NodeDefBuilder::NodeOut> inputs;
    inputs.reserve(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      inputs.emplace_back("a", 0, DT_FLOAT);
    }
    TF_ASSERT_OK(NodeDefBuilder("test", "Pack")
                     .Input(inputs)
                     .Attr("N", num_inputs)
                     .Attr("axis", axis)
                     .Finalize(&op.node_def));
  };

  set_axis(0);
  INFER_OK(op, "?;?;?", "?");

  // One row per non-negative axis; each row is also run with the equivalent
  // negative axis (axis - 3).
  struct AxisCase {
    int axis;
    const char* both_known;      // inputs "[1,3];[1,3];?"
    const char* first_partial;   // inputs "[?,3];[1,3];?"
    const char* first_unknown;   // inputs "[?,?];[1,3];?"
  };
  const AxisCase kAxisCases[] = {
      {0, "[3,d0_0|d1_0,d0_1|d1_1]", "[3,d1_0,d0_1|d1_1]", "[3,d1_0,d1_1]"},
      {1, "[d0_0|d1_0,3,d0_1|d1_1]", "[d1_0,3,d0_1|d1_1]", "[d1_0,3,d1_1]"},
      {2, "[d0_0|d1_0,d0_1|d1_1,3]", "[d1_0,d0_1|d1_1,3]", "[d1_0,d1_1,3]"},
  };
  for (const AxisCase& axis_case : kAxisCases) {
    for (int axis : {axis_case.axis, axis_case.axis - 3}) {
      set_axis(axis);
      INFER_OK(op, "?;?;?", "?");
      INFER_OK(op, "[1,3];[1,3];?", axis_case.both_known);
      INFER_OK(op, "[?,3];[1,3];?", axis_case.first_partial);
      INFER_OK(op, "[?,?];[1,3];?", axis_case.first_unknown);
    }
  }

  // Axes just outside the valid interval on either side.
  set_axis(-4);
  INFER_ERROR("Invalid axis: -4; must be in [-3,3)", op, "[1,3];[1,3];?");
  set_axis(3);
  INFER_ERROR("Invalid axis: 3; must be in [-3,3)", op, "[1,3];[1,3];?");

  set_axis(0);

  // Both parts of the merge error message must be present.
  const char* const kRankMismatchInputs = "[1,2,3];?;[1,4]";
  INFER_ERROR("Shapes must be equal rank, but are 3 and 2", op,
              kRankMismatchInputs);
  INFER_ERROR("From merging shape 0 with other shapes.", op,
              kRankMismatchInputs);
}
129
// Verifies the Unpack shape function for each axis of a rank-3 input, for a
// "num" attr that disagrees with the unpacked dimension, and for bad axes.
TEST(ArrayOpsTest, UnPack_ShapeFn) {
  ShapeInferenceTestOp op("Unpack");

  // Rebuilds op.node_def as an Unpack node with the given attrs.
  auto configure = [&op](int axis, int num) {
    TF_ASSERT_OK(NodeDefBuilder("test", "Unpack")
                     .Input("a", 0, DT_FLOAT)
                     .Attr("axis", axis)
                     .Attr("num", num)
                     .Finalize(&op.node_def));
  };

  configure(0, 1);
  INFER_OK(op, "?", "?");

  // Unpack along the first dimension: one output without dim 0.
  for (int axis : {0, -3}) {
    configure(axis, 1);
    INFER_OK(op, "?", "?");
    for (const char* ins : {"[1,2,3]", "[?,?,?]"}) {
      INFER_OK(op, ins, "[d0_1,d0_2]");
    }
  }

  // Unpack along the middle dimension: two outputs without dim 1.
  for (int axis : {1, -2}) {
    configure(axis, 2);
    for (const char* ins : {"[1,2,3]", "[?,?,?]"}) {
      INFER_OK(op, ins, "[d0_0,d0_2];[d0_0,d0_2]");
    }
  }

  // Unpack along the last dimension: three outputs without dim 2.
  for (int axis : {2, -1}) {
    configure(axis, 3);
    for (const char* ins : {"[1,2,3]", "[?,?,?]"}) {
      INFER_OK(op, ins, "[d0_0,d0_1];[d0_0,d0_1];[d0_0,d0_1]");
    }
  }

  // "num" must match the size of the unpacked dimension.
  configure(2, 2);
  INFER_ERROR("Dimension must be 2 but is 3", op, "[1,2,3]");

  // Axes just outside the valid interval on either side.
  configure(-4, 3);
  INFER_ERROR("Invalid axis: -4; must be in [-3,3)", op, "[1,2,3]");
  configure(3, 3);
  INFER_ERROR("Invalid axis: 3; must be in [-3,3)", op, "[1,2,3]");
}
168
// Verifies the Const shape function: the output shape comes from the "value"
// attr's tensor shape, and a shape with an unknown dimension is rejected.
TEST(ArrayOpsTest, Const_ShapeFn) {
  ShapeInferenceTestOp op("Const");
  TensorProto value_proto;
  auto* value_shape = value_proto.mutable_tensor_shape();

  // Regenerates op.node_def from the current contents of value_proto.
  auto refresh_node_def = [&op, &value_proto]() {
    TF_ASSERT_OK(NodeDefBuilder("test", "Const")
                     .Attr("value", value_proto)
                     .Finalize(&op.node_def));
  };

  // Scalar value.
  TensorShape{}.AsProto(value_shape);
  refresh_node_def();
  INFER_OK(op, "", "[]");

  // Fully defined rank-4 value.
  TensorShape{1, 2, 3, 4}.AsProto(value_shape);
  refresh_node_def();
  INFER_OK(op, "", "[1,2,3,4]");

  // Appending a dimension of size -1 makes the shape not fully defined.
  value_shape->add_dim()->set_size(-1);
  refresh_node_def();
  INFER_ERROR("Shape [1,2,3,4,?] is not fully defined", op, "");
}
190
// Verifies that a set of ops forward the shape of input 0 to output 0
// unchanged.
TEST(ArrayOpsTest, UnchangedShapes_ShapeFn) {
  // Single-input ops.
  for (const char* op_name : {
           "CheckNumerics",
           "Identity",
           "RefIdentity",
           "QuantizeAndDequantize",
           "StopGradient",
           "ZerosLike",
           "OnesLike",
       }) {
    ShapeInferenceTestOp op(op_name);
    for (const char* ins : {"?", "[]", "[1,2,?,4,5]"}) {
      INFER_OK(op, ins, "in0");
    }
  }

  // MatrixBandPart takes three inputs; only input 0 determines the output.
  ShapeInferenceTestOp op("MatrixBandPart");
  for (const char* ins : {"?;?;?", "[];?;?", "[1,2,?,4,5];?;?"}) {
    INFER_OK(op, ins, "in0");
  }
}
213
// Verifies that GuaranteeConst forwards its input shape unchanged.
TEST(ArrayOpsTest, GuaranteeConst_ShapeFn) {
  ShapeInferenceTestOp op("GuaranteeConst");
  for (const char* ins : {"?", "[]", "[1,2,?,4,5]"}) {
    INFER_OK(op, ins, "in0");
  }
}
220
// Verifies that the Identity shape function carries the input's handle
// shape/dtype data over to output 0.
TEST(ArrayOpsTest, Identity_ShapeFnHandles) {
  using ShapeAndTypeList =
      std::vector<std::pair<PartialTensorShape, DataType>>;

  ShapeInferenceTestOp op("Identity");

  const OpRegistrationData* op_reg_data;
  TF_ASSERT_OK(OpRegistry::Global()->LookUp(op.name, &op_reg_data));

  // One input whose handle data holds a single (unknown shape, DT_BOOL) entry.
  std::vector<std::unique_ptr<ShapeAndTypeList>> handle_data;
  handle_data.emplace_back(
      new ShapeAndTypeList({{PartialTensorShape(), DT_BOOL}}));

  shape_inference::InferenceContext c(
      TF_GRAPH_DEF_VERSION, op.node_def, op_reg_data->op_def,
      {PartialTensorShape()}, {}, {}, handle_data);
  TF_ASSERT_OK(c.construction_status());
  ASSERT_TRUE(op_reg_data->shape_inference_fn != nullptr);
  TF_ASSERT_OK(c.Run(op_reg_data->shape_inference_fn));

  // The output must carry exactly one entry, and its dtype must be preserved.
  const auto* output_handle_data = c.output_handle_shapes_and_types(0);
  ASSERT_TRUE(output_handle_data != nullptr);
  ASSERT_EQ(1, output_handle_data->size());
  EXPECT_EQ((*output_handle_data)[0].dtype, DT_BOOL);
}
245
// Verifies the Diag shape function: the output repeats the input dimensions
// twice, and a scalar input is rejected.
TEST(ArrayOpsTest, Diag_ShapeFn) {
  ShapeInferenceTestOp op("Diag");

  struct OkCase {
    const char* ins;
    const char* outs;
  };
  const OkCase kOkCases[] = {
      {"?", "?"},
      {"[1,?,3]", "[d0_0,d0_1,d0_2,d0_0,d0_1,d0_2]"},
      {"[?,1,2,3]", "[d0_0,d0_1,d0_2,d0_3,d0_0,d0_1,d0_2,d0_3]"},
  };
  for (const OkCase& ok_case : kOkCases) {
    INFER_OK(op, ok_case.ins, ok_case.outs);
  }

  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]");
}
253
// Verifies the DiagPart shape function: valid even-rank inputs, inputs with a
// rejected rank, and inputs whose paired dimensions disagree.
TEST(ArrayOpsTest, DiagPart_ShapeFn) {
  ShapeInferenceTestOp op("DiagPart");

  struct OkCase {
    const char* ins;
    const char* outs;
  };
  const OkCase kOkCases[] = {
      {"?", "?"},
      {"[1,?,?,4]", "[d0_0,d0_3]"},
      {"[1,?,3,?,4,3]", "[d0_0,d0_4,d0_2|d0_5]"},
      {"[1,2,3,?,?,?,?,4]", "[d0_0,d0_1,d0_2,d0_7]"},
  };
  for (const OkCase& ok_case : kOkCases) {
    INFER_OK(op, ok_case.ins, ok_case.outs);
  }

  // Ranks 0, 1 and 3 are all rejected with the same message.
  for (const char* bad_rank_ins : {"[]", "[?]", "[1,2,3]"}) {
    INFER_ERROR("Input must have even and non-zero rank", op, bad_rank_ins);
  }

  // Dimension 1 (2) does not match dimension 3 (10).
  INFER_ERROR("Dimensions must be equal, but are 2 and 10", op, "[1,2,?,10]");
}
265
// Verifies the MatrixDiag shape function: the last input dimension is
// duplicated in the output, and a scalar input is rejected.
TEST(ArrayOpsTest, MatrixDiag_ShapeFn) {
  ShapeInferenceTestOp op("MatrixDiag");

  // Unknown rank stays unknown.
  INFER_OK(op, "?", "?");

  // Scalars are not allowed.
  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]");

  // Known rank: the output appends a copy of the last dimension.
  const char* const kVectorIn = "[?]";
  const char* const kRank4In = "[1,?,?,4]";
  INFER_OK(op, kVectorIn, "[d0_0,d0_0]");
  INFER_OK(op, kRank4In, "[d0_0,d0_1,d0_2,d0_3,d0_3]");
}
273
// Verifies the MatrixDiagPart shape function: the two innermost dimensions
// collapse into one, and a rank-1 input is rejected.
TEST(ArrayOpsTest, MatrixDiagPart_ShapeFn) {
  ShapeInferenceTestOp op("MatrixDiagPart");

  INFER_OK(op, "?", "?");
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[?]");

  struct OkCase {
    const char* ins;
    const char* outs;
  };
  const OkCase kOkCases[] = {
      // Square inner matrix: either inner dimension may be reported.
      {"[?,1,2,2]", "[d0_0,d0_1,d0_2|d0_3]"},
      // Non-square inner matrix: the smaller inner dimension is reported.
      {"[?,1,2,3]", "[d0_0,d0_1,d0_2]"},
      {"[?,1,3,2]", "[d0_0,d0_1,d0_3]"},
  };
  for (const OkCase& ok_case : kOkCases) {
    INFER_OK(op, ok_case.ins, ok_case.outs);
  }
}
282
// Verifies the Reverse shape function: the output is input 0, the second
// input must be a vector whose length equals the rank of input 0, and inputs
// with more than 8 dimensions are rejected.
TEST(ArrayOpsTest, Reverse_ShapeFn) {
  ShapeInferenceTestOp op("Reverse");

  INFER_OK(op, "?;?", "in0");

  // The second input must have rank 1.
  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[?,2]");

  // The length of the second input must match the rank of input 0.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];[4]");

  // Nine dimensions is one too many.
  const char* const kTooManyDims =
      "reverse does not work on tensors with more than 8 dimensions";
  INFER_ERROR(kTooManyDims, op, "[1,2,3,4,5,6,7,8,9];[9]");

  for (const char* ins : {"[1,2,3,?];[4]", "[1,2,3,?,5,6,7,8];[8]"}) {
    INFER_OK(op, ins, "in0");
  }
}
294
// Verifies the ReverseV2 shape function: the output is input 0, the second
// input must be a vector, and inputs with more than 8 dimensions are rejected.
TEST(ArrayOpsTest, ReverseV2_ShapeFn) {
  ShapeInferenceTestOp op("ReverseV2");

  INFER_OK(op, "?;?", "in0");

  // The second input must have rank 1.
  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[?,2]");

  // A second input shorter than the rank of input 0 is accepted here.
  INFER_OK(op, "[1,2,3];[2]", "in0");

  // Nine dimensions is one too many.
  const char* const kTooManyDims =
      "reverse does not work on tensors with more than 8 dimensions";
  INFER_ERROR(kTooManyDims, op, "[1,2,3,4,5,6,7,8,9];[9]");

  for (const char* ins : {"[1,2,3,?];[4]", "[1,2,3,?,5,6,7,8];[8]"}) {
    INFER_OK(op, ins, "in0");
  }
}
306
// Verifies the Fill shape function with and without a known value for the
// first input tensor.
TEST(ArrayOpsTest, Fill_ShapeFn) {
  ShapeInferenceTestOp op("Fill");
  AddNodeAttr("index_type", DT_INT32, &op.node_def);
  op.input_tensors.resize(2);

  // No tensor value available: only the rank can be inferred, and only when
  // the length of the first input is known.
  for (const char* ins : {"?;?", "[?];?"}) {
    INFER_OK(op, ins, "?");
  }
  INFER_OK(op, "[4];?", "[?,?,?,?]");

  // With the tensor value supplied, the output shape is fully determined.
  Tensor dims_value = test::AsTensor<int32>({1, 2, 3, 4});
  op.input_tensors[0] = &dims_value;
  INFER_OK(op, "[4];?", "[1,2,3,4]");
}
319
// Verifies the Gather shape function: the second input's shape replaces the
// first dimension of input 0, and a scalar input 0 is rejected.
TEST(ArrayOpsTest, Gather_ShapeFn) {
  ShapeInferenceTestOp op("Gather");

  struct OkCase {
    const char* ins;
    const char* outs;
  };
  const OkCase kOkCases[] = {
      {"?;?", "?"},
      {"[1,?,2];[3]", "[d1_0,d0_1,d0_2]"},
  };
  for (const OkCase& ok_case : kOkCases) {
    INFER_OK(op, ok_case.ins, ok_case.outs);
  }

  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[1,2,3]");
}
326
// Verifies the GatherV2 shape function with an unknown axis, with a known
// axis (positive and negative) against indices of rank 0, 1 and 2, and with
// non-zero batch_dims.
TEST(ArrayOpsTest, GatherV2_ShapeFn) {
  ShapeInferenceTestOp op("GatherV2");
  AddNodeAttr("batch_dims", 0, &op.node_def);

  // Axis value unknown: only the output rank can be inferred.
  INFER_OK(op, "?;?;?", "?");
  INFER_OK(op, "[1,2,3];[3];[]", "[?,?,?]");
  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op,
              "[];[1,2,3];[]");

  // The axis input must be a scalar.
  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];[1,2,3];[1]");

  // From here on the axis tensor value is supplied through input 2.
  Tensor axis_value;
  op.input_tensors.resize(3);
  op.input_tensors[2] = &axis_value;

  // Axis 1 is out of range for a rank-1 first input.
  axis_value = test::AsScalar(1);
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
              "[1];[1,2];[]");

  struct AxisCase {
    int axis;
    const char* ins;
    const char* outs;
  };
  const AxisCase kAxisCases[] = {
      // Rank 0 indices.
      {0, "[1,2,3];[];[]", "[d0_1,d0_2]"},
      {1, "[1,2,3];[];[]", "[d0_0,d0_2]"},
      {2, "[1,2,3];[];[]", "[d0_0,d0_1]"},
      // Rank 1 indices.
      {0, "[1,2,3];[5];[]", "[d1_0,d0_1,d0_2]"},
      {1, "[1,2,3];[5];[]", "[d0_0,d1_0,d0_2]"},
      {2, "[1,2,3];[5];[]", "[d0_0,d0_1,d1_0]"},
      // Rank 2 indices.
      {0, "[1,2,3];[5,6];[]", "[d1_0,d1_1,d0_1,d0_2]"},
      {1, "[1,2,3];[5,6];[]", "[d0_0,d1_0,d1_1,d0_2]"},
      {2, "[1,2,3];[5,6];[]", "[d0_0,d0_1,d1_0,d1_1]"},
      // Negative axis.
      {-3, "[1,2,3];[5,6];[]", "[d1_0,d1_1,d0_1,d0_2]"},
      {-2, "[1,2,3];[5,6];[]", "[d0_0,d1_0,d1_1,d0_2]"},
      {-1, "[1,2,3];[5,6];[]", "[d0_0,d0_1,d1_0,d1_1]"},
  };
  for (const AxisCase& axis_case : kAxisCases) {
    axis_value = test::AsScalar(axis_case.axis);
    INFER_OK(op, axis_case.ins, axis_case.outs);
  }

  // Batch dimensions > 0. Separate test ops are used because the batch_dims
  // attr of the original node cannot be overwritten.
  ShapeInferenceTestOp one_batch_dim_op("GatherV2");
  AddNodeAttr("batch_dims", 1, &one_batch_dim_op.node_def);
  INFER_OK(one_batch_dim_op, "[1,4800,8];[1,28400];[]", "[?,?,?]");

  ShapeInferenceTestOp two_batch_dims_op("GatherV2");
  AddNodeAttr("batch_dims", 2, &two_batch_dims_op.node_def);
  INFER_OK(two_batch_dims_op, "[1,2,3,4,5];[1,2,3];[]", "[?,?,?,?,?]");
}
392
// Verifies the GatherNd shape function. Inputs are (params, indices).
TEST(ArrayOpsTest, GatherNd_ShapeFn) {
  ShapeInferenceTestOp op("GatherNd");

  struct OkCase {
    const char* ins;
    const char* outs;
  };
  const OkCase kOkCases[] = {
      {"?;?", "?"},
      // Last indices dimension 0: all params dimensions are kept.
      {"[1,?,3,?];[?,0]", "[d1_0,d0_0,d0_1,d0_2,d0_3]"},
      // Last indices dimension equals the params rank: none are kept.
      {"[1,?,3,?];[?,4]", "[d1_0]"},
  };
  for (const OkCase& ok_case : kOkCases) {
    INFER_OK(op, ok_case.ins, ok_case.outs);
  }

  // The last dimension of indices may not exceed the rank of params.
  INFER_ERROR("indices.shape[-1] must be <= params.rank", op, "[1,2,3];[4]");
}
404
// Verifies the Shape shape function: the output is a vector whose length is
// the input rank, or unknown when the input rank is unknown.
TEST(ArrayOpsTest, Shape_ShapeFn) {
  ShapeInferenceTestOp op("Shape");

  struct OkCase {
    const char* ins;
    const char* outs;
  };
  const OkCase kOkCases[] = {
      {"?", "[?]"},
      {"[?]", "[1]"},
      {"[?,2,3,4,5]", "[5]"},
  };
  for (const OkCase& ok_case : kOkCases) {
    INFER_OK(op, ok_case.ins, ok_case.outs);
  }
}
411
// Verifies the ShapeN shape function on a three-input node: each output is a
// vector whose length is the rank of the corresponding input.
TEST(ArrayOpsTest, ShapeN_ShapeFn) {
  ShapeInferenceTestOp op("ShapeN");

  const int num_inputs = 3;
  std::vector<NodeDefBuilder::NodeOut> inputs;
  inputs.reserve(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    inputs.emplace_back("a", 0, DT_FLOAT);
  }
  TF_ASSERT_OK(NodeDefBuilder("test", "ShapeN")
                   .Input(inputs)
                   .Attr("N", num_inputs)
                   .Finalize(&op.node_def));

  struct OkCase {
    const char* ins;
    const char* outs;
  };
  const OkCase kOkCases[] = {
      {"?;?;?", "[?];[?];[?]"},
      {"[?];[?];[?]", "[1];[1];[1]"},
      {"[?,2,3,4,5];?;[1,?,3]", "[5];[?];[3]"},
  };
  for (const OkCase& ok_case : kOkCases) {
    INFER_OK(op, ok_case.ins, ok_case.outs);
  }
}
426
// Verifies the Unique shape function: output 0 is a vector of unknown length,
// output 1 is input 0, and a non-vector input is rejected.
TEST(ArrayOpsTest, Unique_ShapeFn) {
  ShapeInferenceTestOp op("Unique");

  for (const char* ins : {"?", "[5]"}) {
    INFER_OK(op, ins, "[?];in0");
  }

  INFER_ERROR("Shape must be rank 1 but is rank 5", op, "[1,2,3,?,5]");
}
433
// Verifies the UniqueWithCounts shape function: outputs 0 and 2 are vectors
// of unknown length and output 1 is input 0.
TEST(ArrayOpsTest, UniqueWithCounts_ShapeFn) {
  ShapeInferenceTestOp op("UniqueWithCounts");

  for (const char* ins : {"?", "[1,2,3,?,5]"}) {
    INFER_OK(op, ins, "[?];in0;[?]");
  }
}
439
// Verifies the InvertPermutation shape function: the output is a vector
// matching the input, and a scalar input is rejected.
TEST(ArrayOpsTest, InvertPermutation_ShapeFn) {
  ShapeInferenceTestOp op("InvertPermutation");

  // Unknown rank becomes a vector of unknown length.
  const char* const kUnknownRankIn = "?";
  INFER_OK(op, kUnknownRankIn, "[?]");

  // A known vector is forwarded as-is.
  const char* const kVectorIn = "[1]";
  INFER_OK(op, kVectorIn, "in0");

  // Scalars are not allowed.
  const char* const kScalarIn = "[]";
  INFER_ERROR("Shape must be rank 1 but is rank 0", op, kScalarIn);
}
446
// Verifies the shape function shared by Pad and MirrorPad. Inputs are
// (input, paddings).
TEST(ArrayOpsTest, PadD_ShapeFn) {
  for (const char* op_name : {"Pad", "MirrorPad"}) {
    ShapeInferenceTestOp op(op_name);
    op.input_tensors.resize(2);

    INFER_OK(op, "?;?", "?");

    // paddings must be a rank-2 tensor whose second dimension is 2.
    INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;[1,2,3]");
    INFER_ERROR("Dimension must be 2 but is 4", op, "?;[1,4]");

    // input.rank and paddings.dim(0) must agree; that value is the output
    // rank.
    INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];[4,2]");
    for (const char* ins : {"[1,2,3];?", "?;[3,2]"}) {
      INFER_OK(op, ins, "[?,?,?]");
    }

    // Supply the paddings tensor value and check the amounts are added.
    // With paddings ((1,10),(2,20),(3,30)), 11, 22 and 33 are added to the
    // corresponding input dims.
    Tensor paddings_value(DT_INT64, TensorShape{3, 2});
    test::FillValues<int64>(&paddings_value, {1, 10, 2, 20, 3, 30});
    op.input_tensors[1] = &paddings_value;

    INFER_OK(op, "[100,200,300];[3,2]", "[111,222,333]");
    INFER_OK(op, "[100,?,300];[3,2]", "[111,?,333]");
    for (const char* ins : {"?;[3,2]", "?;?"}) {
      INFER_OK(op, ins, "[?,?,?]");
    }
  }
}
478
// Verifies the PadV2 shape function. Inputs are (input, paddings,
// constant_values).
TEST(ArrayOpsTest, PadV2_ShapeFn) {
  ShapeInferenceTestOp op("PadV2");
  op.input_tensors.resize(3);

  INFER_OK(op, "?;?;?", "?");

  // paddings must be a rank-2 tensor whose second dimension is 2.
  INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;[1,2,3];?");
  INFER_ERROR("Dimension must be 2 but is 4", op, "?;[1,4];?");

  // input.rank and paddings.dim(0) must agree; that value is the output rank.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];[4,2];[]");
  for (const char* ins : {"[1,2,3];?;[]", "?;[3,2];[]"}) {
    INFER_OK(op, ins, "[?,?,?]");
  }

  // Supply the paddings tensor value and check the amounts are added.
  // With paddings ((1,10),(2,20),(3,30)), 11, 22 and 33 are added to the
  // corresponding input dims.
  Tensor paddings_value(DT_INT64, TensorShape{3, 2});
  test::FillValues<int64>(&paddings_value, {1, 10, 2, 20, 3, 30});
  op.input_tensors[1] = &paddings_value;

  INFER_OK(op, "[100,200,300];[3,2];[]", "[111,222,333]");
  INFER_OK(op, "[100,?,300];[3,2];[]", "[111,?,333]");
  for (const char* ins : {"?;[3,2];[]", "?;?;[]"}) {
    INFER_OK(op, ins, "[?,?,?]");
  }
}
508
// Verifies the MirrorPadGrad shape function. Inputs are (input, paddings).
TEST(ArrayOpsTest, MirrorPadGrad_ShapeFn) {
  ShapeInferenceTestOp op("MirrorPadGrad");
  op.input_tensors.resize(2);

  // Nothing known, or the first paddings dimension unknown: the output rank
  // is unknown.
  for (const char* ins : {"?;?", "?;[?,4]"}) {
    INFER_OK(op, ins, "?");
  }

  // The input rank must match the first paddings dimension.
  INFER_ERROR("must be rank 3 but is rank 2", op, "[?,?];[3,2]");

  // paddings must be a [rank x 2] matrix.
  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 3 and 2", op,
              "[?,?,?];[3,3]");

  // paddings value unknown but rank known: rank-3 output with unknown dims.
  INFER_OK(op, "[?,?,?];[3,2]", "[?,?,?]");

  // Supply the paddings tensor value and check the amounts are subtracted.
  // With paddings ((1,10),(2,20),(3,30)), 11, 22 and 33 are subtracted from
  // the corresponding input dims.
  Tensor paddings_value(DT_INT64, TensorShape{3, 2});
  test::FillValues<int64>(&paddings_value, {1, 10, 2, 20, 3, 30});
  op.input_tensors[1] = &paddings_value;

  INFER_OK(op, "[111,222,333];[3,2]", "[100,200,300]");
  INFER_OK(op, "[111,?,333];[3,2]", "[100,?,300]");
}
540
// Verifies the BroadcastArgs shape function: the output is a vector, and both
// inputs must be vectors.
TEST(ArrayOpsTest, BroadcastArgs_ShapeFn) {
  ShapeInferenceTestOp op("BroadcastArgs");

  INFER_OK(op, "?;?", "[?]");
  for (const char* ins : {"[123];[1]", "[1];[123]", "[123];[121]"}) {
    INFER_OK(op, ins, "[123]");
  }

  // A scalar in either position is rejected.
  for (const char* ins : {"[];?", "?;[]"}) {
    INFER_ERROR("Shape must be rank 1 but is rank 0", op, ins);
  }
}
552
// Verifies the BroadcastTo shape function with and without a known value for
// the second (target shape) input.
TEST(ArrayOpsTest, BroadcastTo_ShapeFn) {
  ShapeInferenceTestOp op("BroadcastTo");
  op.input_tensors.resize(2);

  // Target shape value unknown.
  struct OkCase {
    const char* ins;
    const char* outs;
  };
  const OkCase kUnknownValueCases[] = {
      {"?;[?]", "?"},
      {"[];[1]", "[?]"},
      {"[1];[1]", "[?]"},
      {"[1];[2]", "[?,?]"},
      {"[2,2];[3]", "[?,d0_0,d0_1]"},
  };
  for (const OkCase& ok_case : kUnknownValueCases) {
    INFER_OK(op, ok_case.ins, ok_case.outs);
  }

  // Rank checks.
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[?,?]");
  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[2];[]");
  INFER_ERROR("Shape must be at most rank 1 but is rank 2", op, "[2,2];[1]");

  // Target shape value known to be {2, 10, 3}.
  Tensor target_shape(DT_INT64, TensorShape{3});
  test::FillValues<int64>(&target_shape, {2, 10, 3});
  op.input_tensors[1] = &target_shape;

  const OkCase kKnownValueCases[] = {
      {"[1,?,1];[3]", "[2,10,3]"},
      {"[1,1,1];[3]", "[2,10,3]"},
      {"[10,1];[3]", "[2,d0_0,3]"},
  };
  for (const OkCase& ok_case : kKnownValueCases) {
    INFER_OK(op, ok_case.ins, ok_case.outs);
  }

  // Input dimensions that cannot be broadcast to the target.
  INFER_ERROR("Dimensions must be equal, but are 3 and 2 for", op,
              "[3,1,1];[3]");
  INFER_ERROR("Dimensions must be equal, but are 2 and 10 for", op,
              "[2,2,1];[3]");
}
579
// Verifies the BroadcastGradientArgs shape function: both outputs are vectors
// of unknown length, and both inputs must be vectors.
TEST(ArrayOpsTest, BroadcastGradientArgs_ShapeFn) {
  ShapeInferenceTestOp op("BroadcastGradientArgs");

  // The outputs are unknown-length vectors regardless of input sizes.
  for (const char* ins : {"?;?", "[123];[456]"}) {
    INFER_OK(op, ins, "[?];[?]");
  }

  // A scalar in either position is rejected.
  for (const char* ins : {"[];?", "?;[]"}) {
    INFER_ERROR("Shape must be rank 1 but is rank 0", op, ins);
  }
}
590
// Verifies the ListDiff shape function: both outputs are vectors of unknown
// length, and both inputs must be vectors.
//
// This test previously constructed the op as "BroadcastGradientArgs" (a
// copy-paste from the test above), so ListDiff's shape function was never
// exercised. It now targets "ListDiff"; the expectations are unchanged.
TEST(ArrayOpsTest, ListDiff_ShapeFn) {
  ShapeInferenceTestOp op("ListDiff");

  // Output is always two matching unknown vectors.
  INFER_OK(op, "?;?", "[?];[?]");
  INFER_OK(op, "[123];[456]", "[?];[?]");

  // Rank checks: a scalar in either position is rejected.
  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];?");
  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
}
601
// Verifies the MatrixSetDiag shape function. Inputs are (input, diagonal).
TEST(ArrayOpsTest, MatrixSetDiag_ShapeFn) {
  ShapeInferenceTestOp op("MatrixSetDiag");

  // Rank checks.
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1];?");
  for (const char* ins : {"?;[]", "[2,2];[]"}) {
    INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, ins);
  }
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,2];[2,2]");

  // diagonal[-1] must match the smallest matrix dimension.
  INFER_ERROR("Dimensions must be equal, but are 2 and 3", op, "[2,3];[3]");

  // Cases where the output is exactly input 0.
  for (const char* ins : {
           "?;?",
           "[1,2,2];[1,2]",
           "[1,2,3];?",
           "[1,3,2];?",
           "[1,?,2];[?,?]",
           "[1,?,?];[?,2]",
       }) {
    INFER_OK(op, ins, "in0");
  }

  // The batch shape is taken from diagonal when input is not fully specified.
  INFER_OK(op, "?;[1,2]", "[d1_0,?,?]");
  for (const char* ins :
       {"[?,?,3];[1,2]", "[?,3,?];[1,2]", "[?,3,2];[1,2]"}) {
    INFER_OK(op, ins, "[d1_0,d0_1,d0_2]");
  }
}
630
// Verifies the ExpandDims shape function for every insertion position of a
// rank-3 input (positive and negative index, int32 and int64 dim tensors),
// for out-of-range indices, and for vector-valued dim tensors.
TEST(ArrayOpsTest, ExpandDims_ShapeFn) {
  ShapeInferenceTestOp op("ExpandDims");
  op.input_tensors.resize(2);

  // With unknown dim tensor value, output is unknown.
  INFER_OK(op, "?;?", "?");
  Tensor dim_t;
  op.input_tensors[1] = &dim_t;

  // Installs `dim` as the dim tensor value, then checks that an unknown input
  // stays unknown and that "[5,?,7]" expands to `expected`.
  auto expect_expansion = [&op, &dim_t](const Tensor& dim,
                                        const char* expected) {
    dim_t = dim;
    INFER_OK(op, "?;?", "?");
    INFER_OK(op, "[5,?,7];?", expected);
  };

  // Expand at the front of the tensor (int32 only).
  for (int32_t idx : {0, -4}) {
    expect_expansion(test::AsScalar<int32>(idx), "[1,d0_0,d0_1,d0_2]");
  }

  // Expand in the middle of the tensor, with int32 and then int64.
  for (int32_t idx : {1, -3}) {
    expect_expansion(test::AsScalar<int32>(idx), "[d0_0,1,d0_1,d0_2]");
    expect_expansion(test::AsScalar<int64>(idx), "[d0_0,1,d0_1,d0_2]");
  }
  for (int32_t idx : {2, -2}) {
    expect_expansion(test::AsScalar<int32>(idx), "[d0_0,d0_1,1,d0_2]");
    expect_expansion(test::AsScalar<int64>(idx), "[d0_0,d0_1,1,d0_2]");
  }

  // Expand at the end, with int32 and then int64.
  for (int32_t idx : {3, -1}) {
    expect_expansion(test::AsScalar<int32>(idx), "[d0_0,d0_1,d0_2,1]");
    expect_expansion(test::AsScalar<int64>(idx), "[d0_0,d0_1,d0_2,1]");
  }

  // Indices just outside the valid interval, with int32 and then int64.
  for (int32_t idx : {4, -5}) {
    dim_t = test::AsScalar<int32>(idx);
    INFER_ERROR("not in the interval [-4, 3]", op, "[5,?,7];?");
    dim_t = test::AsScalar<int64>(idx);
    INFER_ERROR("not in the interval [-4, 3]", op, "[5,?,7];?");
  }

  // A one-element vector is accepted as the dim tensor.
  std::vector<int32> dims = {0};
  expect_expansion(test::AsTensor<int32>(dims), "[1,d0_0,d0_1,d0_2]");

  // A two-element vector is rejected.
  dims.push_back(1);
  dim_t = test::AsTensor<int32>(dims);
  for (const char* ins : {"?;?", "[5,6,7];?"}) {
    INFER_ERROR("'dim' input must be a tensor with a single", op, ins);
  }

  // Examples from the ExpandDims doc.
  dim_t = test::AsScalar<int32>(0);
  INFER_OK(op, "[2];[]", "[1,d0_0]");
  dim_t = test::AsScalar<int32>(1);
  INFER_OK(op, "[2];[]", "[d0_0,1]");
  dim_t = test::AsScalar<int32>(-1);
  INFER_OK(op, "[2];[]", "[d0_0,1]");
}
709
// Verifies the ImmutableConst shape function: the output shape comes from the
// "shape" attr, and a "shape" attr of the wrong type is an error.
TEST(ArrayOpsTest, ImmutableConst_ShapeFn) {
  ShapeInferenceTestOp op("ImmutableConst");

  {
    // Rank-3 shape attr.
    TF_ASSERT_OK(NodeDefBuilder("test", "ImmutableConst")
                     .Attr("dtype", DT_FLOAT)
                     .Attr("shape", TensorShape({1, 2, 3}))
                     .Attr("memory_region_name", "test_region")
                     .Finalize(&op.node_def));
    INFER_OK(op, "", "[1,2,3]");
  }

  {
    // Scalar shape attr.
    TF_ASSERT_OK(NodeDefBuilder("test", "ImmutableConst")
                     .Attr("dtype", DT_FLOAT)
                     .Attr("shape", TensorShape({}))
                     .Attr("memory_region_name", "test_region")
                     .Finalize(&op.node_def));
    INFER_OK(op, "", "[]");
  }

  {
    // "shape" given as a string instead of a shape.
    TF_ASSERT_OK(NodeDefBuilder("test", "ImmutableConst")
                     .Attr("dtype", DT_FLOAT)
                     .Attr("shape", "invalid")
                     .Attr("memory_region_name", "test_region")
                     .Finalize(&op.node_def));
    INFER_ERROR("AttrValue had value with type 'string' when 'shape' expected",
                op, "");
  }
}
735
// Verifies the Concat shape function. Input 0 is concat_dim and the remaining
// inputs are the values to concatenate.
TEST(ArrayOpsTest, Concat_ShapeFn) {
  ShapeInferenceTestOp op("Concat");

  // Rebuilds op.node_def as a Concat node with `n` value inputs.
  auto set_n = [&op](int n) {
    std::vector<NodeDefBuilder::NodeOut> values;
    values.reserve(n);
    for (int i = 0; i < n; ++i) {
      values.emplace_back("a", 0, DT_FLOAT);
    }
    TF_ASSERT_OK(NodeDefBuilder("test", "Concat")
                     .Input({"concat_dim", 0, DT_INT32})
                     .Input(values)
                     .Attr("n", n)
                     .Finalize(&op.node_def));
  };

  // Checks that both spellings of a rank-1 value paired with a rank-2 value
  // are rejected for the currently installed concat_dim.
  auto expect_rank_too_low = [&op]() {
    INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
                "[];[100];[10,?]");
    INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
                "[];[100,5];[10]");
  };

  // concat_dim (input 0) must be a scalar.
  set_n(2);
  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];?;?");

  // concat_dim value unknown: the output takes the known rank of the values
  // and has that many unknown dims.
  set_n(7);
  INFER_OK(op, "?;?;?;?;[1,2,3];?;[3,2,1];?", "[?,?,?]");
  set_n(4);
  INFER_OK(op, "?;?;?;[1,2,3,4];[4,3,2,1]", "[?,?,?,?]");
  INFER_OK(op, "?;?;?;?;?", "?");  // output rank unknown
  INFER_ERROR("Can't concatenate scalars (use tf.stack instead)", op,
              "?;?;?;[];[]");
  INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;?;?;[1,2];[1,2,3]");

  // concat_dim value known: the concatenated dimension is summed across all
  // values and the other dimensions are merged.
  Tensor concat_dim_t;
  op.input_tensors.push_back(&concat_dim_t);
  set_n(2);

  // Sum dim 0, merge the other two dims.
  for (int concat_dim : {0, -3}) {
    concat_dim_t = test::AsScalar(concat_dim);
    INFER_OK(op, "[];[100,2,?];[10,?,3]", "[110,d1_1,d2_2]");
    INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
                "[];[100,2,5];[10,?,3]");
    // The concat dimension can't be summed when one value is unknown.
    for (const char* ins : {"[];[100,2,?];[?,?,3]", "[];[?,2,?];[10,?,3]"}) {
      INFER_OK(op, ins, "[?,d1_1,d2_2]");
    }
  }

  // A higher concat_dim, spelled positively and then negatively.
  for (bool use_negative : {false, true}) {
    concat_dim_t = test::AsScalar(use_negative ? -2 : 1);
    INFER_OK(op, "[];[1,100,?];[?,10,3]", "[d1_0,110,d2_2]");
    concat_dim_t = test::AsScalar(use_negative ? -1 : 1);
    INFER_OK(op, "[];[1,100];[?,10]", "[d1_0,110]");
    INFER_OK(op, "[];[?,100];[1,10]", "[d2_0,110]");

    // concat_dim is out of bounds.
    concat_dim_t = test::AsScalar(use_negative ? -2 : 1);
    expect_rank_too_low();
  }

  // concat_dim is too low.
  concat_dim_t = test::AsScalar(-2);
  expect_rank_too_low();

  // Repeat a successful case with several unknown inputs.
  set_n(5);
  concat_dim_t = test::AsScalar(1);
  INFER_OK(op, "[];?;[1,100,?];[?,?,?];[?,10,3];?", "[d2_0,?,d4_2]");
}
809
// Verifies the ConcatV2 shape function. The values to concatenate come first
// and the axis is the last input.
TEST(ArrayOpsTest, ConcatV2_ShapeFn) {
  ShapeInferenceTestOp op("ConcatV2");

  // Rebuilds op.node_def as a ConcatV2 node with `n` value inputs.
  auto set_n = [&op](int n) {
    std::vector<NodeDefBuilder::NodeOut> values;
    values.reserve(n);
    for (int i = 0; i < n; ++i) {
      values.emplace_back("a", 0, DT_FLOAT);
    }
    TF_ASSERT_OK(NodeDefBuilder("test", "ConcatV2")
                     .Input(values)
                     .Input({"axis", 0, DT_INT32})
                     .Attr("n", n)
                     .Finalize(&op.node_def));
  };

  // Checks that both spellings of a rank-1 value paired with a rank-2 value
  // are rejected for the currently installed axis.
  auto expect_rank_too_low = [&op]() {
    INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
                "[100];[10,?];[]");
    INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
                "[100,5];[10];[]");
  };

  // The axis (last input) must be a scalar.
  set_n(2);
  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;?;[1]");

  // Axis value unknown: the output takes the known rank of the values and has
  // that many unknown dims.
  set_n(7);
  INFER_OK(op, "?;?;?;?;[1,2,3];?;[3,2,1];?", "[?,?,?]");
  set_n(4);
  INFER_OK(op, "?;?;[1,2,3,4];[4,3,2,1];?", "[?,?,?,?]");
  INFER_OK(op, "?;?;?;?;?", "?");  // output rank unknown
  INFER_ERROR("Can't concatenate scalars (use tf.stack instead)", op,
              "?;?;[];[];?");
  INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;?;[1,2];[1,2,3];?");

  // Axis value known: the concatenated dimension is summed across all values
  // and the other dimensions are merged.
  Tensor concat_dim_t;
  op.input_tensors.resize(3);
  op.input_tensors[2] = &concat_dim_t;

  set_n(2);

  // Disabled check for an invalid concat dim value:
  // concat_dim_t = test::AsScalar(-1);
  // INFER_ERROR("Expected concat_dim >= 0, but got -1", op, "?;?;?");

  // Sum dim 0, merge the other two dims.
  concat_dim_t = test::AsScalar(0);
  INFER_OK(op, "[100,2,?];[10,?,3];[]", "[110,d0_1,d1_2]");
  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
              "[100,2,5];[10,?,3];[]");
  // The concat dimension can't be summed when one value is unknown.
  for (const char* ins : {"[100,2,?];[?,?,3];[]", "[?,2,?];[10,?,3];[]"}) {
    INFER_OK(op, ins, "[?,d0_1,d1_2]");
  }

  // A higher axis.
  concat_dim_t = test::AsScalar(1);
  INFER_OK(op, "[1,100,?];[?,10,3];[]", "[d0_0,110,d1_2]");
  INFER_OK(op, "[1,100];[?,10];[]", "[d0_0,110]");
  INFER_OK(op, "[?,100];[1,10];[]", "[d1_0,110]");
  // Axis is too high.
  expect_rank_too_low();
  // Axis is too low.
  concat_dim_t = test::AsScalar(-2);
  expect_rank_too_low();

  // Repeat a successful case with several unknown inputs; the axis tensor
  // value now sits at input index 5.
  op.input_tensors.resize(6);
  op.input_tensors[3] = nullptr;
  op.input_tensors[5] = &concat_dim_t;
  concat_dim_t = test::AsScalar(1);

  set_n(5);
  INFER_OK(op, "?;[1,100,?];[?,?,?];[?,10,3];?;[]", "[d1_0,?,d3_2]");
}
885
// ConcatOffset: each output takes the shape of the matching shape input
// (inputs 1..4); input 0 is concat_dim.
TEST(ArrayOpsTest, ConcatOffset_ShapeFn) {
  ShapeInferenceTestOp op("ConcatOffset");

  const int kNumShapes = 4;
  std::vector<NodeDefBuilder::NodeOut> shape_inputs;
  shape_inputs.reserve(kNumShapes);
  for (int idx = 0; idx < kNumShapes; ++idx) {
    shape_inputs.emplace_back("a", 0, DT_INT32);
  }

  // concat_dim comes first, followed by the list of shape inputs.
  TF_ASSERT_OK(NodeDefBuilder("test", "ConcatOffset")
                   .Input({"concat_dim", 0, DT_INT32})
                   .Input(shape_inputs)
                   .Attr("n", kNumShapes)
                   .Finalize(&op.node_def));
  INFER_OK(op, "?;?;?;?;?", "in1;in2;in3;in4");
}
900
// Reshape: the output shape comes from the value of the second input when
// that tensor is available; otherwise the output is unknown.
TEST(ArrayOpsTest, Reshape_ShapeFn) {
  ShapeInferenceTestOp op("Reshape");
  op.input_tensors.resize(2);

  // Without a shape tensor value, nothing can be inferred.
  INFER_OK(op, "?;?", "?");
  INFER_OK(op, "[?];?", "?");
  INFER_OK(op, "?;[?]", "?");
  INFER_OK(op, "[?];[?]", "?");
  INFER_OK(op, "[4];[?]", "?");

  // Fully specified target shape.
  Tensor target_shape = test::AsTensor<int32>({1, 2, 3});
  op.input_tensors[1] = &target_shape;
  INFER_OK(op, "?;[3]", "[1,2,3]");
  INFER_OK(op, "[?];[3]", "[1,2,3]");
  INFER_OK(op, "[6];[3]", "[1,2,3]");
  // Element counts of input and target must agree.
  INFER_ERROR(
      "Cannot reshape a tensor with 12 elements to shape [1,2,3] (6 elements)",
      op, "[3,4];[3]");

  // Target shapes containing -1 (a dimension to be inferred).
  // Flatten to rank 1:
  target_shape = test::AsTensor<int32>({-1});
  INFER_OK(op, "?;[1]", "[?]");
  INFER_OK(op, "[?];[1]", "[d0_0]");
  INFER_OK(op, "[2,2];[1]", "[4]");
  // A single -1 next to a known dimension is computed from the element count:
  target_shape = test::AsTensor<int32>({2, -1});
  INFER_OK(op, "[3,4];[2]", "[2,6]");
  // The element count must divide evenly by the known dimensions.
  INFER_ERROR("Dimension size must be evenly divisible by 2 but is 7", op,
              "[7];[2]");
  // More than one -1 cannot be resolved; those dimensions stay unknown.
  target_shape = test::AsTensor<int32>({-1, -1, 2});
  INFER_OK(op, "[8];[3]", "[?,?,2]");
  INFER_OK(op, "?;[3]", "[?,?,2]");

  // An unknown input dimension is propagated symbolically to the output.
  target_shape = test::AsTensor<int32>({-1, 2, 3});
  INFER_OK(op, "[?,2,3];[3]", "[d0_0,2,3]");

  // Empty target shape: reshape to a scalar.
  target_shape = test::AsTensor<int32>({});
  INFER_OK(op, "[1];[0]", "[]");
  INFER_ERROR(
      "Cannot reshape a tensor with 2 elements to shape [] (1 elements)", op,
      "[1,2];[0]");

  // Inputs that contain zero elements.
  target_shape = test::AsTensor<int32>({-1});
  INFER_OK(op, "[0];[1]", "[0]");
  target_shape = test::AsTensor<int32>({-1, 6});
  INFER_OK(op, "[0,2];[1]", "[0,6]");
  target_shape = test::AsTensor<int32>({0, -1});
  INFER_OK(op, "[0,2];[1]", "[0,?]");
}
960
// QuantizedReshape: output 0 follows the reshape cases below; outputs 1 and 2
// are expected to be scalars in every successful case.
TEST(ArrayOpsTest, QuantizedReshape_ShapeFn) {
  ShapeInferenceTestOp op("QuantizedReshape");
  op.input_tensors.resize(2);

  // Only a subset of the Reshape_ShapeFn cases is repeated here, as
  // QuantizedReshape uses the same code for the reshape part of the operation.
  INFER_OK(op, "?;?;?;?", "?;[];[]");
  INFER_OK(op, "[?];?;?;?", "?;[];[]");
  INFER_OK(op, "[?];[?];?;?", "?;[];[]");
  INFER_OK(op, "[4];[?];?;?", "?;[];[]");

  Tensor target_shape = test::AsTensor<int32>({1, 2, 3});
  op.input_tensors[1] = &target_shape;
  INFER_OK(op, "[?];[3];?;?", "[1,2,3];[];[]");
  INFER_OK(op, "[6];[3];?;?", "[1,2,3];[];[]");
  INFER_ERROR(
      "Cannot reshape a tensor with 12 elements to shape [1,2,3] (6 elements)",
      op, "[3,4];[3];?;?");

  // input_min (input 2) and input_max (input 3) must each be rank 0.
  INFER_ERROR("must be rank 0", op, "?;?;[1];?");
  INFER_ERROR("must be rank 0", op, "?;?;?;[1]");
}
983
// Placeholder: the inferred output shape is taken from the "shape" attr.
// Each scoped case builds its own op so the attrs do not leak between cases.
TEST(ArrayOpsTest, Placeholder_ShapeFn) {
  {
    // Fully defined rank-2 shape.
    ShapeInferenceTestOp op("Placeholder");
    TensorShape matrix_shape({1, 2});
    TF_ASSERT_OK(NodeDefBuilder("test", "Placeholder")
                     .Attr("shape", matrix_shape)
                     .Attr("dtype", DT_FLOAT)
                     .Finalize(&op.node_def));
    INFER_OK(op, "", "[1,2]");
  }

  {
    // A scalar (rank-0) shape is accepted.
    ShapeInferenceTestOp op("Placeholder");
    TensorShape scalar_shape({});
    TF_ASSERT_OK(NodeDefBuilder("test", "Placeholder")
                     .Attr("shape", scalar_shape)
                     .Attr("dtype", DT_FLOAT)
                     .Finalize(&op.node_def));
    INFER_OK(op, "", "[]");
  }

  {
    // Partially known shape: -1 marks the unknown dimension.
    ShapeInferenceTestOp op("Placeholder");
    const int64 dims[2] = {1, -1};
    PartialTensorShape partial_shape;
    TF_ASSERT_OK(PartialTensorShape::MakePartialShape(dims, 2, &partial_shape));
    TF_ASSERT_OK(NodeDefBuilder("test", "Placeholder")
                     .Attr("shape", partial_shape)
                     .Attr("dtype", DT_FLOAT)
                     .Finalize(&op.node_def));
    INFER_OK(op, "", "[1,?]");
  }

  {
    // A default-constructed PartialTensorShape yields an unknown shape.
    ShapeInferenceTestOp op("Placeholder");
    PartialTensorShape unknown_shape;
    TF_ASSERT_OK(NodeDefBuilder("test", "Placeholder")
                     .Attr("shape", unknown_shape)
                     .Attr("dtype", DT_FLOAT)
                     .Finalize(&op.node_def));
    INFER_OK(op, "", "?");
  }
}
1031
// Transpose: checks output shapes with and without a known permutation
// tensor (input 1).
TEST(ArrayOpsTest, Transpose_ShapeFn) {
  ShapeInferenceTestOp op("Transpose");
  op.input_tensors.resize(2);

  // No permutation value available: at most the output rank is known.
  INFER_OK(op, "?;?", "?");
  INFER_OK(op, "?;[?]", "?");
  INFER_OK(op, "?;[2]", "[?,?]");
  INFER_OK(op, "[?];?", "[?]");
  INFER_OK(op, "[?,?];[2]", "[?,?]");
  INFER_ERROR("Dimension must be 3 but is 2", op, "[1,2,3];[2]");

  // Permutation value available.
  Tensor perm_t = test::AsTensor<int32>({0});
  op.input_tensors[1] = &perm_t;
  INFER_OK(op, "[?];[?]", "[d0_0]");
  perm_t = test::AsTensor<int32>({1, 0});
  INFER_OK(op, "?;[2]", "[?,?]");
  INFER_OK(op, "[?,?];[2]", "[d0_1,d0_0]");
  INFER_OK(op, "[1,?];[2]", "[d0_1,d0_0]");
  INFER_OK(op, "?;[0]", "in0");

  // Rejected permutations.
  perm_t = test::AsTensor<int32>({1, 2});
  INFER_ERROR("perm dim 2 is out of range of input rank 2", op, "[1,2];[2]");
  perm_t = test::AsTensor<int32>({0});
  INFER_ERROR("Dimension must be 2 but is 1", op, "[1,2];[1]");

  // A rank-5 permutation.
  perm_t = test::AsTensor<int32>({1, 0, 3, 4, 2});
  INFER_OK(op, "[0,1,2,3,4];[5]", "[d0_1,d0_0,d0_3,d0_4,d0_2]");
  INFER_OK(op, "[0,?,2,3,4];[5]", "[d0_1,d0_0,d0_3,d0_4,d0_2]");
}
1063
// Bitcast: the output shape depends on the relative sizes of the input and
// output element types.
TEST(ArrayOpsTest, Bitcast_ShapeFn) {
  ShapeInferenceTestOp op("Bitcast");
  // Rebuilds op.node_def for a cast from `from_type` to `to_type`.
  auto set_types = [&op](DataType from_type, DataType to_type) {
    TF_ASSERT_OK(NodeDefBuilder("test", "Bitcast")
                     .Input("input", 0, from_type)
                     .Attr("type", to_type)
                     .Finalize(&op.node_def));
  };

  set_types(DT_FLOAT, DT_INT32);
  // Unknown input shape gives an unknown output shape.
  INFER_OK(op, "?", "?");

  // Same-size types: the input shape is passed through.
  INFER_OK(op, "[1,2]", "in0");

  // Smaller to larger type: the last dimension is consumed.
  set_types(DT_INT32, DT_INT64);
  INFER_OK(op, "[1,2]", "[d0_0]");  // last dimension matches divisor.
  // TODO(vrv): Seems like a bug, or at least, too lenient.
  INFER_OK(op, "[1,?]", "[d0_0]");
  // 4 is divisible by 2, but the shape function signature requires
  // that the last dimension matches the last value exactly.
  INFER_ERROR("does not match", op, "[1,4]");
  INFER_ERROR("does not match", op, "[1,3]");

  // Larger to smaller type: a trailing dimension holding the size ratio is
  // appended.
  set_types(DT_INT64, DT_INT32);
  INFER_OK(op, "[4,5]", "[d0_0,d0_1,2]");
  set_types(DT_COMPLEX128, DT_INT32);
  INFER_OK(op, "[4,5]", "[d0_0,d0_1,4]");
  set_types(DT_COMPLEX128, DT_HALF);
  INFER_OK(op, "[4,5]", "[d0_0,d0_1,8]");
  set_types(DT_COMPLEX128, DT_INT8);
  INFER_OK(op, "[4,5]", "[d0_0,d0_1,16]");

  // DT_STRING on either side is rejected; the expected error reports a zero
  // type size.
  set_types(DT_STRING, DT_INT32);
  INFER_ERROR("one of the type sizes is zero", op, "[1,2,3]");
  set_types(DT_INT32, DT_STRING);
  INFER_ERROR("one of the type sizes is zero", op, "[1,2,3]");
}
1106
// Squeeze: covers the default (empty) squeeze_dims attr as well as explicit,
// negative, and out-of-range dims.
TEST(ArrayOpsTest, Squeeze_ShapeFn) {
  ShapeInferenceTestOp op("Squeeze");

  // Rebuilds op.node_def with the given squeeze_dims attr.
  auto set_squeeze_dims = [&op](const std::vector<int32>& dims_to_squeeze) {
    TF_ASSERT_OK(NodeDefBuilder("test", "Squeeze")
                     .Input("input", 0, DT_FLOAT)
                     .Attr("squeeze_dims", dims_to_squeeze)
                     .Finalize(&op.node_def));
  };

  // squeeze_dims = [] (the default).
  set_squeeze_dims({});

  // Unknown input shape gives an unknown output shape.
  INFER_OK(op, "?", "?");

  INFER_OK(op, "[1,4,1,5,1]", "[d0_1,d0_3]");

  // With the default attr, unknown dimensions make the output unknown.
  INFER_OK(op, "[1,?,1,?,1]", "?");

  // A single explicit dimension.
  set_squeeze_dims({1});
  INFER_OK(op, "[4,1,5]", "[d0_0,d0_2]");
  // Squeezing unknown dim explicitly, assumes it's 1 at runtime.
  INFER_OK(op, "[4,?,5]", "[d0_0,d0_2]");

  // An explicit dimension whose size is not 1 is rejected.
  INFER_ERROR("Can not squeeze dim[1]", op, "[4,6,5]");

  // Several dimensions at once, including a negative index.
  set_squeeze_dims({1, 2});
  INFER_OK(op, "[4,1,1,5]", "[d0_0,d0_3]");
  set_squeeze_dims({1, -2});
  INFER_OK(op, "[4,1,1,5]", "[d0_0,d0_3]");

  // A lone negative index.
  set_squeeze_dims({-2});
  INFER_OK(op, "[4,1,5]", "[d0_0,d0_2]");

  // Indices outside [-rank, rank) are rejected.
  set_squeeze_dims({-4});
  INFER_ERROR("not in [-3,3)", op, "[1,2,3]");
  set_squeeze_dims({3});
  INFER_ERROR("not in [-3,3)", op, "[1,2,3]");
}
1153
// ReverseSequence: validates seq_lengths rank and the seq_dim / batch_dim
// attrs, then checks that the input shape is propagated to the output.
TEST(ArrayOpsTest, ReverseSequence_ShapeFn) {
  ShapeInferenceTestOp op("ReverseSequence");
  // Rebuilds op.node_def with the given seq_dim and batch_dim attrs.
  auto set_dims = [&op](const int32_t seq_axis, const int32_t batch_axis) {
    TF_ASSERT_OK(NodeDefBuilder("test", "ReverseSequence")
                     .Input("input", 0, DT_FLOAT)
                     .Input("seq_lengths", 1, DT_INT64)
                     .Attr("seq_dim", seq_axis)
                     .Attr("batch_dim", batch_axis)
                     .Finalize(&op.node_def));
  };

  set_dims(1, 2);
  // Unknown input shape gives an unknown output shape.
  INFER_OK(op, "?;[10]", "?");

  // seq_lengths must be rank 1.
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[10,10]");

  // seq_dim and batch_dim must both be below the input rank.
  set_dims(1, 4);
  INFER_ERROR("batch_dim must be < input rank", op, "[1,2,3];[3]");
  set_dims(4, 1);
  INFER_ERROR("seq_dim must be < input rank", op, "[1,2,3];[3]");

  set_dims(1, 2);
  INFER_OK(op, "[1,2,3];[3]", "[d0_0,d0_1,d0_2]");
  // An unknown batch dimension is resolved by merging with seq_lengths.
  INFER_OK(op, "[1,2,?];[3]", "[d0_0,d0_1,d1_0]");
  INFER_OK(op, "[1,2,3];[?]", "[d0_0,d0_1,d0_2]");
}
1185
// Split (num_split = 2): checks the two outputs with unknown, positive,
// negative, and out-of-range split_dim values.
TEST(ArrayOpsTest, Split_ShapeFn) {
  ShapeInferenceTestOp op("Split");
  op.input_tensors.resize(2);

  // Neither the split_dim value nor the value shape is known.
  TF_ASSERT_OK(NodeDefBuilder("test", "Split")
                   .Input("split_dim", 0, DT_INT32)
                   .Input("value", 1, DT_FLOAT)
                   .Attr("num_split", 2)
                   .Finalize(&op.node_def));
  INFER_OK(op, "?;?", "?;?");
  // A known input rank fixes the rank of every output.
  INFER_OK(op, "?;[?,?]", "[?,?];[?,?]");

  // Known input dims but unknown split_dim: only the rank carries over.
  INFER_OK(op, "?;[1,4]", "[?,?];[?,?]");

  // From here on the split_dim tensor is attached.
  Tensor split_dim_t = test::AsTensor<int32>({1, 2});
  op.input_tensors[0] = &split_dim_t;
  INFER_ERROR("Input must be scalar but has rank 1", op, "[?];[?,?]");
  split_dim_t = test::AsScalar<int32>(1);
  INFER_OK(op, "?;?", "?;?");
  INFER_OK(op, "?;[?,?]", "[d1_0,?];[d1_0,?]");
  INFER_OK(op, "?;[1,4]", "[d1_0,2];[d1_0,2]");
  INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
  INFER_ERROR("Dimension size must be evenly divisible by 2 but is 5", op,
              "?;[1,5]");

  // split_dim at or above the input rank.
  split_dim_t = test::AsScalar<int32>(3);
  INFER_ERROR(
      "Dimension size, given by scalar input 3 must be in range [-3, 3)", op,
      "?;[1,4,8]");

  // Negative split_dim values.
  split_dim_t = test::AsScalar<int32>(-1);
  INFER_OK(op, "?;?", "?;?");
  INFER_OK(op, "?;[?,?]", "[d1_0,?];[d1_0,?]");
  INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
  INFER_OK(op, "?;[1,4]", "[d1_0,2];[d1_0,2]");
  INFER_OK(op, "?;[1,4,8]", "[d1_0,d1_1,4];[d1_0,d1_1,4]");
  split_dim_t = test::AsScalar<int32>(-2);
  INFER_OK(op, "?;[1,4,8]", "[d1_0,2,d1_2];[d1_0,2,d1_2]");
  split_dim_t = test::AsScalar<int32>(-4);
  INFER_ERROR(
      "Dimension size, given by scalar input -4 must be in range [-3, 3)", op,
      "?;[1,4,8]");
}
1235
// Tile: the output rank follows the input or the length of `multiples`; the
// output dims are computed once the multiples tensor is available.
TEST(ArrayOpsTest, Tile_ShapeFn) {
  ShapeInferenceTestOp op("Tile");
  op.input_tensors.resize(2);

  // Node with a float input and int32 multiples; no tensor values attached.
  TF_ASSERT_OK(NodeDefBuilder("test", "Tile")
                   .Input("input", 0, DT_FLOAT)
                   .Input("multiples", 1, DT_INT32)
                   .Finalize(&op.node_def));

  // Nothing known about either input: output is unknown.
  INFER_OK(op, "?;?", "?");

  // Known input rank, unknown multiples: output rank is known.
  INFER_OK(op, "[2,3,1,4];?", "[?,?,?,?]");

  // 'multiples' must be rank 1.
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,3,1,4];[4,1]");

  // No multiples value, but its length gives the output rank.
  INFER_OK(op, "?;[4]", "[?,?,?,?]");

  // 4D input with int32 multiples.
  Tensor multiples_t = test::AsTensor<int32>({2, 3, 4, 5});
  op.input_tensors[1] = &multiples_t;
  INFER_OK(op, "[2,3,1,4];[4]", "[4,9,4,20]");
  // Same case with an int64 multiples tensor.
  multiples_t = test::AsTensor<int64>({2, 3, 4, 5});
  INFER_OK(op, "[2,3,1,4];[4]", "[4,9,4,20]");
}
1266
// EditDistance: the output shape is derived from the hypothesis_shape
// (input 2) and truth_shape (input 5) tensor values.
TEST(ArrayOpsTest, EditDistance_ShapeFn) {
  ShapeInferenceTestOp op("EditDistance");
  op.input_tensors.resize(6);

  // Without the two shape tensors the output shape is unknown.
  INFER_OK(op, "[?,?];[?];[4];[?,?];[?];[4]", "?");

  Tensor hyp_shape_t = test::AsTensor<int64>({2, 30, 4, 50});
  op.input_tensors[2] = &hyp_shape_t;
  Tensor truth_shape_t = test::AsTensor<int64>({20, 3, 40, 5});
  op.input_tensors[5] = &truth_shape_t;
  INFER_OK(op, "[?,?];[?];[4];[?,?];[?];[4]", "[20,30,40]");

  // The two shape tensors must have the same number of elements.
  hyp_shape_t = test::AsTensor<int64>({2});
  op.input_tensors[2] = &hyp_shape_t;
  INFER_ERROR("Num elements of hypothesis_shape does not match truth_shape", op,
              "[?,?];[?];[1];[?,?];[?];[4]");
}
1286
// OneHot: validates the axis attr and the depth input, then checks where the
// depth dimension is inserted into the output shape.
TEST(ArrayOpsTest, OneHot_ShapeFn) {
  ShapeInferenceTestOp op("OneHot");
  op.input_tensors.resize(4);
  // Rebuilds op.node_def with the given axis attr.
  auto set_axis = [&op](int axis_value) {
    TF_ASSERT_OK(NodeDefBuilder("test", "OneHot")
                     .Input("indices", 0, DT_FLOAT)
                     .Input("depth", 1, DT_INT32)
                     .Input("on_value", 2, DT_FLOAT)
                     .Input("off_value", 3, DT_FLOAT)
                     .Attr("axis", axis_value)
                     .Finalize(&op.node_def));
  };

  // axis below -1 is rejected.
  set_axis(-2);
  INFER_ERROR("axis must be >= -1", op, "?;?;?;?");
  set_axis(1);

  // Unknown indices shape gives an unknown output shape.
  INFER_OK(op, "?;[];?;?", "?");

  // The depth tensor must be a scalar.
  Tensor depth_t = test::AsTensor<int32>({1, 2});
  op.input_tensors[1] = &depth_t;
  INFER_ERROR("Input must be scalar but has rank 1", op, "?;[2];?;?");

  // Everything known: depth lands at `axis`, or last when axis is -1.
  depth_t = test::AsScalar<int32>(2);
  INFER_OK(op, "[1,3,4];[];?;?", "[d0_0,2,d0_1,d0_2]");
  set_axis(-1);
  INFER_OK(op, "[1,3,4];[];?;?", "[d0_0,d0_1,d0_2,2]");
}
1319
// ExtractImagePatches: focuses on how `rates` enlarges the effective kernel
// size, plus validation of the ksizes attr length.
TEST(ArrayOpsTest, ExtractImagePatchesShapeTest) {
  ShapeInferenceTestOp op("ExtractImagePatches");
  // Rebuilds op.node_def with the given attrs.
  auto configure = [&op](const std::vector<int32>& ksizes,
                         const std::vector<int32>& strides,
                         const std::vector<int32>& rates,
                         const string& padding) {
    TF_ASSERT_OK(NodeDefBuilder("test", "ExtractImagePatches")
                     .Input("input", 0, DT_FLOAT)
                     .Attr("ksizes", ksizes)
                     .Attr("strides", strides)
                     .Attr("rates", rates)
                     .Attr("padding", padding)
                     .Finalize(&op.node_def));
  };

  // Just tests that the ksize calculation with rates works. Most of
  // the other code is boilerplate that is tested by a variety of
  // other ops.
  //
  // ksizes is 2x2 and the row/col rate is 2, so the effective kernel becomes
  // 2 + (2 - 1) = 3 in each direction. A 7x7 input with a 3x3 filter and 1x1
  // stride gives a 5x5 output.
  configure({1, 2, 2, 1}, {1, 1, 1, 1}, {1, 2, 2, 1}, "VALID");
  INFER_OK(op, "[1,7,7,2]", "[d0_0,5,5,8]");
  // With 1x1 ksizes the output depth is exactly the input depth, and the
  // expected spatial dims stay at 7x7.
  configure({1, 1, 1, 1}, {1, 1, 1, 1}, {1, 2, 2, 1}, "VALID");
  INFER_OK(op, "[1,7,7,2]", "[d0_0,7,7,d0_3]");

  // ksizes with five entries instead of four.
  configure({1, 2, 2, 1, 1}, {1, 1, 1, 1}, {1, 2, 2, 1}, "VALID");
  INFER_ERROR(
      "ExtractImagePatches requires the ksizes attribute to contain 4 values, "
      "but got: 5",
      op, "[1,7,7,2]");
}
1355
// QuantizeAndDequantizeV2 (axis = -1): the output mirrors input 0, and
// non-scalar input_min / input_max shapes are rejected.
TEST(ArrayOpsTest, QuantizeAndDequantizeV2_ShapeFn) {
  ShapeInferenceTestOp op("QuantizeAndDequantizeV2");
  op.input_tensors.resize(3);
  TF_ASSERT_OK(NodeDefBuilder("test", "QuantizeAndDequantizeV2")
                   .Input("input", 0, DT_FLOAT)
                   .Input("input_min", 1, DT_FLOAT)
                   .Input("input_max", 2, DT_FLOAT)
                   .Attr("signed_input", true)
                   .Attr("num_bits", 8)
                   .Attr("range_given", false)
                   .Attr("narrow_range", false)
                   .Attr("axis", -1)
                   .Finalize(&op.node_def));

  // The output shape is the shape of input 0, whatever is known about it.
  INFER_OK(op, "?;?;?", "in0");
  INFER_OK(op, "[];?;?", "in0");
  INFER_OK(op, "[1,2,?,4,5];?;?", "in0");

  // Rank-1 input_min and/or input_max.
  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1,2,?,4,5];[1];[]");
  INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op,
              "[1,2,?,4,5];[];[1]");
  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1,2,?,4,5];[1];[1]");
}
1378
// SpaceToBatch (block_size = 2): checks the output batch and spatial dims
// with unknown, valid, and invalid paddings.
TEST(ArrayOpsTest, SpaceToBatch_ShapeFn) {
  ShapeInferenceTestOp op("SpaceToBatch");
  op.input_tensors.resize(2);
  TF_ASSERT_OK(NodeDefBuilder("test", "SpaceToBatch")
                   .Input("input", 0, DT_FLOAT)
                   .Input("paddings", 1, DT_INT32)
                   .Attr("block_size", 2)
                   .Finalize(&op.node_def));

  // Paddings value unknown: the batch size is still computed.
  INFER_OK(op, "[1,10,10,3];[2,2]", "[4,?,?,d0_3]");

  // Unknown paddings shape means unknown width and height.
  INFER_OK(op, "[1,10,10,3];?", "[4,?,?,d0_3]");

  // Paddings with the wrong rank or wrong dimensions.
  INFER_ERROR("rank", op, "[1,10,10,3];[4]");
  INFER_ERROR("3 and 2", op, "[1,10,10,3];[2,3]");

  // Known paddings, first as int32 and then as int64.
  Tensor paddings_t = test::AsTensor<int32>({4, 2, 2, 4}, {{2, 2}});
  op.input_tensors[1] = &paddings_t;
  INFER_OK(op, "[1,10,10,3];[2,2]", "[4,8,8,d0_3]");
  paddings_t = test::AsTensor<int64>({4, 2, 2, 4}, {{2, 2}});
  INFER_OK(op, "[1,10,10,3];[2,2]", "[4,8,8,d0_3]");

  // Padded size not divisible by the block size.
  paddings_t = test::AsTensor<int32>({1, 2, 3, 4}, {{2, 2}});
  op.input_tensors[1] = &paddings_t;
  INFER_ERROR("Dimension size must be evenly divisible by 2 but is 13", op,
              "[1,10,10,3];[2,2]");

  // A negative padding value.
  paddings_t = test::AsTensor<int32>({1, -2, 3, 4}, {{2, 2}});
  op.input_tensors[1] = &paddings_t;
  INFER_ERROR("cannot be negative", op, "[1,10,10,3];[2,2]");
}
1415
// SpaceToBatchND: block_shape is input 1 and paddings is input 2. Each scoped
// case attaches the tensors it needs and detaches them again before leaving
// the scope, so no pointer outlives its tensor.
TEST(ArrayOpsTest, SpaceToBatchND_ShapeFn) {
  ShapeInferenceTestOp op("SpaceToBatchND");
  op.input_tensors.resize(3);
  TF_ASSERT_OK(NodeDefBuilder("test", "SpaceToBatchND")
                   .Input("input", 0, DT_FLOAT)
                   .Input("block_shape", 1, DT_INT32)
                   .Input("paddings", 2, DT_INT32)
                   .Finalize(&op.node_def));

  // Input shape and paddings shape may both be unknown.
  INFER_OK(op, "?;[2];?", "?");

  // Only the input rank is known.
  INFER_OK(op, "[?,?,?,?];[2];?", "[?,?,?,d0_3]");

  // Some input dimensions are known.
  INFER_OK(op, "[?,?,?,2];[2];?", "[?,?,?,d0_3]");

  {
    // Partially known input with a known block_shape value.
    Tensor block_shape_t = test::AsTensor<int32>({2, 3});
    op.input_tensors[1] = &block_shape_t;
    INFER_OK(op, "[3,?,?,2];[2];?", "[18,?,?,d0_3]");

    // Partially known input; block_shape and paddings values known.
    {
      Tensor paddings_t = test::AsTensor<int32>({1, 1, 0, 1}, {{2, 2}});
      op.input_tensors[2] = &paddings_t;
      INFER_OK(op, "[3,?,2,2];[2];[2,2]", "[18,?,1,d0_3]");
      op.input_tensors[2] = nullptr;
    }

    // Fully known input; block_shape and paddings values known.
    {
      Tensor paddings_t = test::AsTensor<int32>({1, 1, 0, 0}, {{2, 2}});
      op.input_tensors[2] = &paddings_t;
      INFER_OK(op, "[3,2,3,2];[2];[2,2]", "[18,2,1,d0_3]");
      op.input_tensors[2] = nullptr;
    }

    op.input_tensors[1] = nullptr;
  }

  // block_shape must be rank 1 with a known length.
  INFER_ERROR("block_shape must have rank 1", op, "?;[1,1];?");
  INFER_ERROR("block_shape must have known size", op, "?;[?];?");

  {
    // A zero entry in block_shape.
    Tensor block_shape_t = test::AsTensor<int32>({0, 2});
    op.input_tensors[1] = &block_shape_t;
    INFER_ERROR("block_shape must be positive", op, "[1,2,2];[2];[2,2]");
    op.input_tensors[1] = nullptr;
  }

  {
    // A negative entry in paddings.
    Tensor block_shape_t = test::AsTensor<int32>({1, 1});
    op.input_tensors[1] = &block_shape_t;
    Tensor paddings_t = test::AsTensor<int32>({0, -1, 0, 0}, {{2, 2}});
    op.input_tensors[2] = &paddings_t;
    INFER_ERROR("paddings cannot be negative", op, "[1,2,2];[2];[2,2]");
    op.input_tensors[1] = nullptr;
    op.input_tensors[2] = nullptr;
  }

  {
    // Spatial dims that are not divisible by the block shape.
    Tensor block_shape_t = test::AsTensor<int32>({3, 3});
    op.input_tensors[1] = &block_shape_t;
    Tensor paddings_t = test::AsTensor<int32>({0, 0, 0, 0}, {{2, 2}});
    op.input_tensors[2] = &paddings_t;
    INFER_ERROR("divisible", op, "[1,2,3,1];[2];[2,2]");
    op.input_tensors[1] = nullptr;
    op.input_tensors[2] = nullptr;
  }

  {
    // Empty block_shape and paddings tensors.
    Tensor block_shape_t = test::AsTensor<int32>({});
    op.input_tensors[1] = &block_shape_t;
    Tensor paddings_t = test::AsTensor<int32>({});
    op.input_tensors[2] = &paddings_t;
    INFER_OK(op, "?;[0];[0,2]", "?");
    op.input_tensors[1] = nullptr;
    op.input_tensors[2] = nullptr;
  }

  // Paddings with the wrong rank or wrong shape.
  INFER_ERROR("rank", op, "[1,3,3,1];[2];[1]");
  INFER_ERROR("shape", op, "[1,3,3,1];[2];[1,2]");
}
1502
// BatchToSpace (block_size = 2): checks the output batch and spatial dims
// with unknown, valid, and invalid crops.
TEST(ArrayOpsTest, BatchToSpace_ShapeFn) {
  ShapeInferenceTestOp op("BatchToSpace");
  op.input_tensors.resize(2);
  TF_ASSERT_OK(NodeDefBuilder("test", "BatchToSpace")
                   .Input("input", 0, DT_FLOAT)
                   .Input("crops", 1, DT_INT32)
                   .Attr("block_size", 2)
                   .Finalize(&op.node_def));

  // Crops value unknown: the batch size is still computed.
  INFER_OK(op, "[4,8,8,3];[2,2]", "[1,?,?,d0_3]");

  // Batch size that the block size does not divide.
  INFER_ERROR("Dimension size must be evenly divisible by", op,
              "[5,8,8,3];[2,2]");

  // Unknown crops shape means unknown width and height.
  INFER_OK(op, "[4,8,8,3];?", "[1,?,?,d0_3]");

  // Crops with the wrong rank or wrong dimensions.
  INFER_ERROR("rank", op, "[4,8,8,3];[4]");
  INFER_ERROR("3 and 2", op, "[4,8,8,3];[2,3]");

  // Known crops, supplied as int64.
  Tensor crops_t = test::AsTensor<int64>({4, 2, 2, 4}, {{2, 2}});
  op.input_tensors[1] = &crops_t;
  INFER_OK(op, "[4,8,8,3];[2,2]", "[1,10,10,d0_3]");

  // Crops larger than the available size.
  crops_t = test::AsTensor<int32>({100, 2, 3, 4}, {{2, 2}});
  op.input_tensors[1] = &crops_t;
  INFER_ERROR("Negative dimension size caused by subtracting", op,
              "[4,8,8,3];[2,2]");
  crops_t = test::AsTensor<int32>({1, 2, 3, 400}, {{2, 2}});
  op.input_tensors[1] = &crops_t;
  INFER_ERROR("Negative dimension size caused by subtracting", op,
              "[4,8,8,3];[2,2]");

  // A negative crop value.
  crops_t = test::AsTensor<int32>({1, -2, 3, 4}, {{2, 2}});
  op.input_tensors[1] = &crops_t;
  INFER_ERROR("cannot be negative", op, "[4,8,8,3];[2,2]");
}
1545
// BatchToSpaceND: block_shape is input 1 and crops is input 2. Each scoped
// case attaches the tensors it needs and detaches them again before leaving
// the scope, so no pointer outlives its tensor.
TEST(ArrayOpsTest, BatchToSpaceND_ShapeFn) {
  ShapeInferenceTestOp op("BatchToSpaceND");
  op.input_tensors.resize(3);
  TF_ASSERT_OK(NodeDefBuilder("test", "BatchToSpaceND")
                   .Input("input", 0, DT_FLOAT)
                   .Input("block_shape", 1, DT_INT32)
                   .Input("crops", 2, DT_INT32)
                   .Finalize(&op.node_def));

  // Input shape and crops shape may both be unknown.
  INFER_OK(op, "?;[2];?", "?");

  // Only the input rank is known.
  INFER_OK(op, "[?,?,?,?];[2];?", "[?,?,?,d0_3]");

  {
    // Partially known input with a known block_shape value.
    Tensor block_shape_t = test::AsTensor<int32>({2, 3});
    op.input_tensors[1] = &block_shape_t;
    INFER_OK(op, "[?,?,?,2];[2];?", "[?,?,?,d0_3]");

    INFER_OK(op, "[18,?,?,2];[2];?", "[3,?,?,d0_3]");

    // Partially known input; block_shape and crops values known.
    {
      Tensor crops_t = test::AsTensor<int32>({1, 1, 0, 1}, {{2, 2}});
      op.input_tensors[2] = &crops_t;
      INFER_OK(op, "[18,?,2,2];[2];[2,2]", "[3,?,5,d0_3]");
      op.input_tensors[2] = nullptr;
    }

    // Fully known input; block_shape and crops values known.
    {
      Tensor crops_t = test::AsTensor<int32>({1, 1, 0, 0}, {{2, 2}});
      op.input_tensors[2] = &crops_t;
      INFER_OK(op, "[18,2,1,2];[2];[2,2]", "[3,2,3,d0_3]");
      op.input_tensors[2] = nullptr;
    }

    op.input_tensors[1] = nullptr;
  }

  // block_shape must be rank 1 with a known length, and the input rank must
  // be compatible with it.
  INFER_ERROR("block_shape must have rank 1", op, "?;[1,1];?");
  INFER_ERROR("block_shape must have known size", op, "?;[?];?");
  INFER_ERROR("rank", op, "[2,2];[2];[2,2]");
  INFER_ERROR("rank", op, "[2,2,3];[3];[3,2]");

  {
    // A zero entry in block_shape.
    Tensor block_shape_t = test::AsTensor<int32>({0, 2});
    op.input_tensors[1] = &block_shape_t;
    INFER_ERROR("block_shape must be positive", op, "[1,2,2];[2];[2,2]");
    op.input_tensors[1] = nullptr;
  }

  {
    // A negative entry in crops.
    Tensor block_shape_t = test::AsTensor<int32>({1, 1});
    op.input_tensors[1] = &block_shape_t;
    Tensor crops_t = test::AsTensor<int32>({0, -1, 0, 0}, {{2, 2}});
    op.input_tensors[2] = &crops_t;
    INFER_ERROR("crops cannot be negative", op, "[1,2,2];[2];[2,2]");
    op.input_tensors[1] = nullptr;
    op.input_tensors[2] = nullptr;
  }

  // The amount to crop exceeds the padded size.
  {
    Tensor block_shape_t = test::AsTensor<int32>({2, 2});
    op.input_tensors[1] = &block_shape_t;
    Tensor crops_t = test::AsTensor<int32>({3, 2, 0, 0}, {{2, 2}});
    op.input_tensors[2] = &crops_t;
    INFER_ERROR("Negative", op, "[4,2,3,1];[2];[2,2]");
    op.input_tensors[1] = nullptr;
    op.input_tensors[2] = nullptr;
  }

  // The batch size is not divisible by the product of the block_shape.
  {
    Tensor block_shape_t = test::AsTensor<int32>({2, 3});
    op.input_tensors[1] = &block_shape_t;
    INFER_ERROR("divisible", op, "[3,1,1,1];[2];[2,2]");
    op.input_tensors[1] = nullptr;
  }
}
1629
// SpaceToDepth (block_size = 2): spatial dims are divided by the block size
// and the depth grows accordingly.
TEST(ArrayOpsTest, SpaceToDepth_ShapeFn) {
  ShapeInferenceTestOp op("SpaceToDepth");
  TF_ASSERT_OK(NodeDefBuilder("test", "SpaceToDepth")
                   .Input("input", 0, DT_FLOAT)
                   .Attr("block_size", 2)
                   .Finalize(&op.node_def));

  INFER_OK(op, "[1,2,4,4]", "[d0_0,1,2,16]");

  // Height or width that the block size does not divide.
  INFER_ERROR("Dimension size must be evenly divisible by 2 but is 3", op,
              "[1,3,8,4]");
  INFER_ERROR("Dimension size must be evenly divisible by 2 but is 5", op,
              "[1,2,5,4]");

  // An unknown input depth stays unknown in the output.
  INFER_OK(op, "[1,2,4,?]", "[d0_0,1,2,?]");
}
1648
// DepthToSpace: spatial dims are multiplied by the block size and the depth
// is divided by its square. Checked for block sizes 2 and 10.
TEST(ArrayOpsTest, DepthToSpace_ShapeFn) {
  ShapeInferenceTestOp op("DepthToSpace");
  TF_ASSERT_OK(NodeDefBuilder("test", "DepthToSpace")
                   .Input("input", 0, DT_FLOAT)
                   .Attr("block_size", 2)
                   .Finalize(&op.node_def));

  INFER_OK(op, "[1,1,2,16]", "[d0_0,2,4,4]");

  // Depth that is not a multiple of block_size squared.
  INFER_ERROR("Dimension size must be evenly divisible by 4 but is 15", op,
              "[1,1,2,15]");

  // An unknown input depth stays unknown in the output.
  INFER_OK(op, "[1,2,4,?]", "[d0_0,4,8,?]");

  // Same op rebuilt with block_size = 10.
  TF_ASSERT_OK(NodeDefBuilder("test", "DepthToSpace")
                   .Input("input", 0, DT_FLOAT)
                   .Attr("block_size", 10)
                   .Finalize(&op.node_def));
  INFER_OK(op, "[1,1,2,200]", "[d0_0,10,20,2]");
}
1672
// Slice: checks output rank inference without tensor values, then exact
// output dims once the begin and sizes tensors are attached.
TEST(ArrayOpsTest, Slice_ShapeFn) {
  ShapeInferenceTestOp op("Slice");
  TF_ASSERT_OK(NodeDefBuilder("test", "Slice")
                   .Input("input", 0, DT_FLOAT)
                   .Input("begin", 1, DT_INT64)
                   .Input("sizes", 2, DT_INT64)
                   .Finalize(&op.node_def));

  // Input rank and begin/sizes lengths are known but their values are not,
  // so only the output rank can be determined.
  INFER_OK(op, "[2,3,4,5];[4];[4]", "[?,?,?,?]");

  // begin/sizes of unknown length: the input still supplies the rank.
  INFER_OK(op, "[2,3,4,5];[?];[?]", "[?,?,?,?]");
  // Nothing known at all.
  INFER_OK(op, "?;[?];[?]", "?");
  // The length of begin alone gives the output rank.
  INFER_OK(op, "?;[4];[?]", "[?,?,?,?]");

  // begin and sizes must each be rank 1.
  INFER_ERROR("must be rank 1", op, "[2,3,4,5];[2,3];[3]");
  INFER_ERROR("must be rank 1", op, "[2,3,4,5];[2];[3,4]");
  // The length of begin must equal the input rank.
  INFER_ERROR("must be rank 2", op, "[2,3,4,5];[2];[2]");

  // From here on the begin and sizes tensors are attached.
  op.input_tensors.resize(3);
  Tensor begin_t = test::AsTensor<int32>({0, 1, 2, 1});
  Tensor sizes_t = test::AsTensor<int32>({1, 2, 1, 3});
  op.input_tensors[1] = &begin_t;
  op.input_tensors[2] = &sizes_t;
  INFER_OK(op, "[2,3,4,5];[4];[4]", "[1,2,1,3]");

  // -1 in sizes means "get the rest"
  sizes_t = test::AsTensor<int32>({-1, -1, 1, -1});
  INFER_OK(op, "[2,3,4,5];[4];[4]", "[d0_0,2,1,4]");

  // begin beyond the end of a dimension.
  begin_t = test::AsTensor<int32>({0, 1, 2, 6});
  sizes_t = test::AsTensor<int32>({-1, -1, -1, -1});
  INFER_ERROR("Negative dimension size", op, "[2,3,4,5];[4];[4]");

  // A size below -1.
  begin_t = test::AsTensor<int32>({0, 1, 2, 5});
  sizes_t = test::AsTensor<int32>({-1, -1, -1, -2});
  INFER_ERROR("cannot be < -1", op, "[2,3,4,5];[4];[4]");
}
1718
// Shape inference for StridedSlice with shrink_axis_mask = 1 and a known
// stride of 1: the expectations below drop the leading dimension.
TEST(ArrayOpsTest, StridedSlice_ShapeFn) {
  ShapeInferenceTestOp op("StridedSlice");

  NodeDefBuilder builder("test", "StridedSlice");
  builder.Input("input", 0, DT_FLOAT)
      .Input("begin", 1, DT_INT32)
      .Input("end", 2, DT_INT32)
      .Input("strides", 3, DT_INT32)
      .Attr("shrink_axis_mask", 1);
  TF_ASSERT_OK(builder.Finalize(&op.node_def));

  // Only the strides input (index 3) gets a constant value; begin and end
  // remain unknown tensors.
  constexpr int kStridesInput = 3;
  Tensor unit_stride = test::AsTensor<int32>({1});
  op.input_tensors.resize(kStridesInput + 1);
  op.input_tensors[kStridesInput] = &unit_stride;

  // Slicing on the 0-th dimension.
  INFER_OK(op, "[2,3,4,5];[1];[1];[1]", "[3,4,5]");
  // Same slice, but one of the surviving dimensions has size 0.
  INFER_OK(op, "[2,0,3,4];[1];[1];[1]", "[0,3,4]");
}
1736
// Shape inference for StridedSliceGrad: the output is derived from input 0,
// first from its shape only and then from its constant value.
TEST(ArrayOpsTest, StridedSliceGrad_ShapeFn) {
  ShapeInferenceTestOp op("StridedSliceGrad");
  op.input_tensors.resize(5);

  // No constant available for input 0: at most the output rank is known, and
  // only when the length of input 0 is known.
  INFER_OK(op, "?;?;?;?;?", "?");
  INFER_OK(op, "[?];?;?;?;?", "?");
  INFER_OK(op, "[4];?;?;?;?", "[?,?,?,?]");

  // With a constant for input 0 the output dimensions equal its elements.
  Tensor shape_t = test::AsTensor<int32>({1, 2, 3, 4});
  op.input_tensors.front() = &shape_t;
  INFER_OK(op, "[4];?;?;?;?", "[1,2,3,4]");
}
1748
// Ops whose output shape equals input 0 and whose two trailing inputs must be
// scalars.
TEST(ArrayOpsTest, UnchangedWithQuantizationScalars_ShapeFn) {
  for (const char* const op_name : {"Dequantize", "FakeQuantWithMinMaxVars"}) {
    ShapeInferenceTestOp op(op_name);

    // Of the two names above only "Dequantize" starts with 'D'; it is the one
    // that gets an explicit NodeDef.
    const bool is_dequantize = op_name[0] == 'D';
    if (is_dequantize) {
      NodeDefBuilder builder("test", "Dequantize");
      builder.Input("input", 0, DT_QINT8)
          .Input("input_min", 1, DT_FLOAT)
          .Input("input_max", 2, DT_FLOAT)
          .Attr("T", DataTypeToEnum<qint8>::v())
          .Attr("mode", "SCALED")
          .Attr("axis", -1);
      TF_ASSERT_OK(builder.Finalize(&op.node_def));
    }

    // Output mirrors input 0 whether or not the shapes are known.
    for (const char* const shapes : {"?;?;?", "[1,?,3];[];[]"}) {
      SCOPED_TRACE(shapes);
      INFER_OK(op, shapes, "in0");
    }

    // Rank check scalars: a vector in either trailing position is rejected.
    for (const char* const shapes : {"[1,?,3];[1];[]", "[1,?,3];[];[1]"}) {
      SCOPED_TRACE(shapes);
      INFER_ERROR("be rank 0", op, shapes);
    }
  }
}
1770
// Shape inference for FakeQuantWithMinMaxVarsPerChannel: the output mirrors
// input 0, and the two trailing inputs are vectors tied to its last dimension.
TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannel) {
  ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannel");

  // Accepted combinations; the output is always input 0.
  for (const char* const shapes :
       {"?;?;?", "[?];?;?", "[1,?,3];[3];[3]", "[3];[3];[3]"}) {
    SCOPED_TRACE(shapes);
    INFER_OK(op, shapes, "in0");
  }

  // Rank check vectors: a scalar in either trailing position is rejected.
  for (const char* const shapes : {"[1,?,3];[1];[]", "[1,?,3];[];[1]"}) {
    SCOPED_TRACE(shapes);
    INFER_ERROR("be rank 1", op, shapes);
  }

  // Vectors must match each other, and match last dim of input.
  for (const char* const shapes : {"[1,?,3];[2];[?]", "[1,?,3];[?];[2]",
                                   "[1,?,?];[1];[2]", "[5];[4];[?]"}) {
    SCOPED_TRACE(shapes);
    INFER_ERROR("must be equal", op, shapes);
  }
}
1789
// Shape inference for FakeQuantWithMinMaxVarsPerChannelGradient: three
// outputs, the first shaped like input 0 and the other two like input 3.
TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannelGradient) {
  ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannelGradient");

  // With nothing known, the last two outputs are still vectors.
  INFER_OK(op, "?;?;?;?", "in0;[?];[?]");

  // Known shapes of increasing rank.
  for (const char* const shapes :
       {"[3];[3];[3];[3]", "[1,3];[1,3];[3];[3]",
        "[1,2,3,4];[1,2,3,4];[4];[4]"}) {
    SCOPED_TRACE(shapes);
    INFER_OK(op, shapes, "in0;in3;in3");
  }

  // Rank check vectors. Each entry pairs the expected error substring with the
  // input shapes that trigger it.
  struct RejectedCase {
    const char* message;
    const char* shapes;
  };
  const RejectedCase kBadRanks[] = {
      {"be equal rank", "[1,?,3];[1,?,3];[3];[]"},
      {"be rank 1", "[1,?,3];[1,?,3];[];[3]"},
      {"be at least rank 1", "[];[];[1];[1]"},
      {"be at most rank 4", "[1,2,3,4,5];[1,2,3,4,5];[1];[1]"},
  };
  for (const RejectedCase& bad : kBadRanks) {
    SCOPED_TRACE(bad.shapes);
    INFER_ERROR(bad.message, op, bad.shapes);
  }

  // Vectors must match each other, and match last dim of input.
  for (const char* const shapes :
       {"[1,3];[1,3];[2];[3]", "[1,3];[1,3];[3];[2]"}) {
    SCOPED_TRACE(shapes);
    INFER_ERROR("must be equal", op, shapes);
  }
}
1808
// Shape inference for QuantizedConcat: a concat_dim scalar, N quantized values
// and two lists of N float limits, producing the concatenated shape plus two
// scalar outputs.
TEST(ArrayOpsTest, QuantizedConcat_ShapeFn) {
  ShapeInferenceTestOp op("QuantizedConcat");

  // Rewrites op.node_def for `n` concatenated inputs. The assertion lives in
  // the lambda, so a builder failure is reported from here.
  auto set_n = [&op](int n) {
    const std::vector<NodeDefBuilder::NodeOut> values(
        n, NodeDefBuilder::NodeOut("a", 0, DT_QUINT8));
    const std::vector<NodeDefBuilder::NodeOut> limits(
        n, NodeDefBuilder::NodeOut("b", 0, DT_FLOAT));
    TF_ASSERT_OK(NodeDefBuilder("test", "QuantizedConcat")
                     .Input({"concat_dim", 0, DT_INT32})
                     .Input(values)
                     .Input(limits)
                     .Input(limits)
                     .Attr("N", n)
                     .Finalize(&op.node_def));
  };

  // Confirm dimension[0] of the input (the concat_dim) is a scalar.
  set_n(1);
  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];?;?;?");

  // Last 2*<N> are all scalars: a vector in any of those four slots fails.
  set_n(2);
  for (const char* const shapes :
       {"[];?;?;?;?;?;[1]", "[];?;?;?;?;[1];?", "[];?;?;?;[1];?;?",
        "[];?;?;[1];?;?;?"}) {
    SCOPED_TRACE(shapes);
    INFER_ERROR("must be rank 0", op, shapes);
  }

  // First is concat dim; next N must be compatible for concat.
  set_n(2);
  INFER_ERROR("must be rank 2", op, "[];[1,2];[1,2,3];?;?;?;?");
  INFER_OK(op, "[];[1,2];[1,3];?;?;?;?", "[?,?];[];[]");

  // Test when the concat_dim tensor is known. The concatenated dimension is
  // summed across all input tensors, and other dimensions are merged.
  // Here concat_dim is 0: sum dim 0, merge the other two dims.
  Tensor concat_dim_t = test::AsScalar(0);
  op.input_tensors.push_back(&concat_dim_t);
  set_n(2);
  INFER_OK(op, "[];[100,2,?];[10,?,3];?;?;?;?", "[110,d1_1,d2_2];[];[]");
  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
              "[];[100,2,5];[10,?,3];?;?;?;?");
  // Note that other cases of concat are covered in the Concat tests.
}
1854
// Shape inference for _ParallelConcatStart: the op has no inputs and its
// output shape is read from the "shape" attribute.
TEST(StateOpsTest, _ParallelConcatStart_ShapeFn) {
  ShapeInferenceTestOp op("_ParallelConcatStart");

  // Serialize the desired output shape so it can be attached as an attr.
  TensorShapeProto expected_shape;
  TensorShape({1, 2, 3}).AsProto(&expected_shape);

  NodeDefBuilder builder("test", "_ParallelConcatStart");
  builder.Attr("shape", expected_shape).Attr("dtype", DT_FLOAT);
  TF_ASSERT_OK(builder.Finalize(&op.node_def));

  INFER_OK(op, "", "[1,2,3]");
}
1866
1867 } // end namespace tensorflow
1868