• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/framework/node_def_builder.h"
17 #include "tensorflow/core/framework/node_def_util.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20 #include "tensorflow/core/framework/shape_inference_testutil.h"
21 #include "tensorflow/core/framework/tensor.pb.h"
22 #include "tensorflow/core/framework/tensor_shape.pb.h"
23 #include "tensorflow/core/framework/tensor_testutil.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/public/version.h"
27 
28 namespace tensorflow {
29 
// Shape-inference checks for TensorScatterUpdate: accepted (input, indices,
// updates) triples are expected to produce input 0 as the output; mismatched
// triples are expected to fail with the listed messages.
TEST(ArrayOpsTest, TensorScatterUpdate_ShapeFn) {
  ShapeInferenceTestOp scatter_op("TensorScatterUpdate");

  // Accepted combinations, each expecting "in0" as the output.
  for (const char* shapes :
       {"[4,3];[8,2];[8]", "[?,?];[?,2];[?]", "[?];[?];[?]"}) {
    INFER_OK(scatter_op, shapes, "in0");
  }

  // Scalar first input.
  INFER_ERROR("Shape must be at least rank 1 but is rank 0", scatter_op,
              "[];[?,2];[?]");
  // First input has a zero-sized dimension while indices/updates are given.
  INFER_ERROR("Indices and updates specified for empty input", scatter_op,
              "[0,2,2];[8,2];[8]");
  // Leading dimension of indices (8) disagrees with that of updates (9).
  INFER_ERROR(
      "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match "
      "dimensions [0,1) of updates[shape=[9]] = [9]",
      scatter_op, "[?,?];[8,2];[9]");
  // Trailing dimensions of updates do not line up with the input.
  INFER_ERROR(
      "Dimensions [2,2) of input[shape=[?,?]] = [] must match "
      "dimensions [1,2) of updates[shape=[?,1]] = [1]",
      scatter_op, "[?,?];[?,2];[?,1]");
}
50 
// Shape-inference checks for ScatterNd (inputs: indices, updates, shape).
TEST(ArrayOpsTest, ScatterNd_ShapeFn) {
  ShapeInferenceTestOp scatter_nd_op("ScatterNd");

  // A length-2 shape input is expected to give a rank-2 unknown output.
  INFER_OK(scatter_nd_op, "[8,2];[8];[2]", "[?,?]");

  // The third input must be a vector, not a scalar.
  INFER_ERROR("Shape must be rank 1 but is rank 0", scatter_nd_op,
              "[?,2];[?];[]");
  // Leading dimension of indices (8) disagrees with that of updates (9).
  INFER_ERROR(
      "Dimensions [0,1) of indices[shape=[8,2]] = [8] must match "
      "dimensions [0,1) of updates[shape=[9]] = [9]",
      scatter_nd_op, "[8,2];[9];[?]");
}
62 
// Shape-inference checks for UnravelIndex (inputs: indices, dims).
TEST(ArrayOpsTest, UnravelIndex_ShapeFn) {
  ShapeInferenceTestOp unravel_op("UnravelIndex");

  // Both inputs unknown.
  INFER_OK(unravel_op, "?;?", "?");

  // Scalar indices: output is a vector sized by the dims input.
  INFER_OK(unravel_op, "[];[?]", "[d1_0]");

  // Non-scalar indices with a vector dims input.
  INFER_OK(unravel_op, "[4,5];[?]", "[d1_0,20]");
  INFER_OK(unravel_op, "[2,3,4];[?]", "[d1_0,24]");
  INFER_OK(unravel_op, "?;[?]", "?");
  INFER_OK(unravel_op, "[?];[?]", "[d1_0,?]");

  // A rank-2 dims input is rejected.
  INFER_ERROR("Shape must be rank 1 but is rank 2", unravel_op, "?;[1,1]");
}
77 
// Shape-inference checks for Pack with three inputs, covering every valid
// axis (positive and negative), out-of-range axes, and rank mismatches.
TEST(ArrayOpsTest, Pack_ShapeFn) {
  ShapeInferenceTestOp pack_op("Pack");

  // Rebuilds the node definition with three float inputs and the given axis.
  auto use_axis = [&pack_op](int axis) {
    const int num_inputs = 3;
    std::vector<NodeDefBuilder::NodeOut> inputs(
        num_inputs, NodeDefBuilder::NodeOut("a", 0, DT_FLOAT));
    TF_ASSERT_OK(NodeDefBuilder("test", "Pack")
                     .Input(inputs)
                     .Attr("N", num_inputs)
                     .Attr("axis", axis)
                     .Finalize(&pack_op.node_def));
  };

  use_axis(0);
  INFER_OK(pack_op, "?;?;?", "?");

  // New dimension (size 3) goes first.
  for (int axis : {0, -3}) {
    use_axis(axis);
    INFER_OK(pack_op, "?;?;?", "?");
    INFER_OK(pack_op, "[1,3];[1,3];?", "[3,d0_0|d1_0,d0_1|d1_1]");
    INFER_OK(pack_op, "[?,3];[1,3];?", "[3,d1_0,d0_1|d1_1]");
    INFER_OK(pack_op, "[?,?];[1,3];?", "[3,d1_0,d1_1]");
  }
  // New dimension goes in the middle.
  for (int axis : {1, -2}) {
    use_axis(axis);
    INFER_OK(pack_op, "?;?;?", "?");
    INFER_OK(pack_op, "[1,3];[1,3];?", "[d0_0|d1_0,3,d0_1|d1_1]");
    INFER_OK(pack_op, "[?,3];[1,3];?", "[d1_0,3,d0_1|d1_1]");
    INFER_OK(pack_op, "[?,?];[1,3];?", "[d1_0,3,d1_1]");
  }
  // New dimension goes last.
  for (int axis : {2, -1}) {
    use_axis(axis);
    INFER_OK(pack_op, "?;?;?", "?");
    INFER_OK(pack_op, "[1,3];[1,3];?", "[d0_0|d1_0,d0_1|d1_1,3]");
    INFER_OK(pack_op, "[?,3];[1,3];?", "[d1_0,d0_1|d1_1,3]");
    INFER_OK(pack_op, "[?,?];[1,3];?", "[d1_0,d1_1,3]");
  }

  // Axes just outside [-3,3) are rejected for rank-2 inputs.
  use_axis(-4);
  INFER_ERROR("Invalid axis: -4; must be in [-3,3)", pack_op, "[1,3];[1,3];?");
  use_axis(3);
  INFER_ERROR("Invalid axis: 3; must be in [-3,3)", pack_op, "[1,3];[1,3];?");

  use_axis(0);

  // The rank-mismatch error is expected to carry both message fragments.
  INFER_ERROR("Shapes must be equal rank, but are 3 and 2", pack_op,
              "[1,2,3];?;[1,4]");
  INFER_ERROR("From merging shape 0 with other shapes.", pack_op,
              "[1,2,3];?;[1,4]");
}
129 
// Shape-inference checks for Unpack across every valid axis of a rank-3
// input, plus a `num` mismatch and out-of-range axes.
TEST(ArrayOpsTest, UnPack_ShapeFn) {
  ShapeInferenceTestOp unpack_op("Unpack");

  // Rebuilds the node definition with the given axis and output count.
  auto configure = [&unpack_op](int axis, int num) {
    TF_ASSERT_OK(NodeDefBuilder("test", "Unpack")
                     .Input("a", 0, DT_FLOAT)
                     .Attr("axis", axis)
                     .Attr("num", num)
                     .Finalize(&unpack_op.node_def));
  };

  configure(0, 1);
  INFER_OK(unpack_op, "?", "?");

  // Unpacking along the first dimension leaves dims 1 and 2.
  for (int axis : {0, -3}) {
    configure(axis, 1);
    INFER_OK(unpack_op, "?", "?");
    INFER_OK(unpack_op, "[1,2,3]", "[d0_1,d0_2]");
    INFER_OK(unpack_op, "[?,?,?]", "[d0_1,d0_2]");
  }
  // Unpacking along the middle dimension leaves dims 0 and 2.
  for (int axis : {1, -2}) {
    configure(axis, 2);
    INFER_OK(unpack_op, "[1,2,3]", "[d0_0,d0_2];[d0_0,d0_2]");
    INFER_OK(unpack_op, "[?,?,?]", "[d0_0,d0_2];[d0_0,d0_2]");
  }
  // Unpacking along the last dimension leaves dims 0 and 1.
  for (int axis : {2, -1}) {
    configure(axis, 3);
    INFER_OK(unpack_op, "[1,2,3]", "[d0_0,d0_1];[d0_0,d0_1];[d0_0,d0_1]");
    INFER_OK(unpack_op, "[?,?,?]", "[d0_0,d0_1];[d0_0,d0_1];[d0_0,d0_1]");
  }

  // `num` (2) does not match the size of the unpacked dimension (3).
  configure(2, 2);
  INFER_ERROR("Dimension must be 2 but is 3", unpack_op, "[1,2,3]");

  // Axes just outside [-3,3) are rejected.
  configure(-4, 3);
  INFER_ERROR("Invalid axis: -4; must be in [-3,3)", unpack_op, "[1,2,3]");
  configure(3, 3);
  INFER_ERROR("Invalid axis: 3; must be in [-3,3)", unpack_op, "[1,2,3]");
}
168 
// Shape-inference checks for Const: the output shape is taken from the
// "value" attr's tensor shape, which must be fully defined.
TEST(ArrayOpsTest, Const_ShapeFn) {
  ShapeInferenceTestOp const_op("Const");
  TensorProto value_proto;
  auto* value_shape = value_proto.mutable_tensor_shape();

  // Re-finalizes the node definition from the current contents of the proto.
  auto refresh_node_def = [&const_op, &value_proto]() {
    TF_ASSERT_OK(NodeDefBuilder("test", "Const")
                     .Attr("value", value_proto)
                     .Finalize(&const_op.node_def));
  };

  // Scalar value.
  TensorShape{}.AsProto(value_shape);
  refresh_node_def();
  INFER_OK(const_op, "", "[]");

  // Rank-4 value.
  TensorShape{1, 2, 3, 4}.AsProto(value_shape);
  refresh_node_def();
  INFER_OK(const_op, "", "[1,2,3,4]");

  // Appending a dimension of size -1 makes the shape not fully defined.
  value_shape->add_dim()->set_size(-1);
  refresh_node_def();
  INFER_ERROR("Shape [1,2,3,4,?] is not fully defined", const_op, "");
}
190 
// Ops whose output is expected to be exactly their first input.
TEST(ArrayOpsTest, UnchangedShapes_ShapeFn) {
  static const char* const kPassThroughOps[] = {
      "CheckNumerics", "Identity",  "RefIdentity", "QuantizeAndDequantize",
      "StopGradient",  "ZerosLike", "OnesLike",
  };
  for (const char* pass_through_name : kPassThroughOps) {
    ShapeInferenceTestOp pass_through_op(pass_through_name);
    INFER_OK(pass_through_op, "?", "in0");
    INFER_OK(pass_through_op, "[]", "in0");
    INFER_OK(pass_through_op, "[1,2,?,4,5]", "in0");
  }

  // MatrixBandPart takes three inputs; inputs 1 and 2 are left unknown here
  // and output 0 is expected to be input 0.
  ShapeInferenceTestOp band_part_op("MatrixBandPart");
  INFER_OK(band_part_op, "?;?;?", "in0");
  INFER_OK(band_part_op, "[];?;?", "in0");
  INFER_OK(band_part_op, "[1,2,?,4,5];?;?", "in0");
}
213 
// GuaranteeConst is expected to forward its input shape unchanged.
TEST(ArrayOpsTest, GuaranteeConst_ShapeFn) {
  ShapeInferenceTestOp guarantee_const_op("GuaranteeConst");
  for (const char* input_shape : {"?", "[]", "[1,2,?,4,5]"}) {
    INFER_OK(guarantee_const_op, input_shape, "in0");
  }
}
220 
// Runs Identity's shape function with handle data attached to the input and
// checks that the handle dtype shows up on the output.
TEST(ArrayOpsTest, Identity_ShapeFnHandles) {
  using ShapesAndTypes = std::vector<std::pair<PartialTensorShape, DataType>>;

  ShapeInferenceTestOp identity_op("Identity");

  // Look up the registered op to get its OpDef and shape function.
  const OpRegistrationData* registration;
  TF_ASSERT_OK(OpRegistry::Global()->LookUp(identity_op.name, &registration));

  // One input handle entry: unknown shape, DT_BOOL.
  std::vector<std::unique_ptr<ShapesAndTypes>> input_handles;
  input_handles.emplace_back(
      new ShapesAndTypes({{PartialTensorShape(), DT_BOOL}}));

  shape_inference::InferenceContext ctx(
      TF_GRAPH_DEF_VERSION, identity_op.node_def, registration->op_def,
      {PartialTensorShape()}, {}, {}, input_handles);
  TF_ASSERT_OK(ctx.construction_status());
  ASSERT_TRUE(registration->shape_inference_fn != nullptr);
  TF_ASSERT_OK(ctx.Run(registration->shape_inference_fn));

  // The single output must carry exactly one handle entry of dtype DT_BOOL.
  const auto* output_handles = ctx.output_handle_shapes_and_types(0);
  ASSERT_TRUE(output_handles != nullptr);
  ASSERT_EQ(1, output_handles->size());
  EXPECT_EQ((*output_handles)[0].dtype, DT_BOOL);
}
245 
// Diag: the expected output repeats the input dimensions twice; a scalar
// input is rejected.
TEST(ArrayOpsTest, Diag_ShapeFn) {
  ShapeInferenceTestOp diag_op("Diag");
  INFER_OK(diag_op, "?", "?");
  INFER_OK(diag_op, "[1,?,3]", "[d0_0,d0_1,d0_2,d0_0,d0_1,d0_2]");
  INFER_OK(diag_op, "[?,1,2,3]",
           "[d0_0,d0_1,d0_2,d0_3,d0_0,d0_1,d0_2,d0_3]");
  INFER_ERROR("Shape must be at least rank 1 but is rank 0", diag_op, "[]");
}
253 
// DiagPart: the input must have even, non-zero rank, and its first and
// second halves must have matching dimensions.
TEST(ArrayOpsTest, DiagPart_ShapeFn) {
  ShapeInferenceTestOp diag_part_op("DiagPart");

  INFER_OK(diag_part_op, "?", "?");
  INFER_OK(diag_part_op, "[1,?,?,4]", "[d0_0,d0_3]");
  INFER_OK(diag_part_op, "[1,?,3,?,4,3]", "[d0_0,d0_4,d0_2|d0_5]");
  INFER_OK(diag_part_op, "[1,2,3,?,?,?,?,4]", "[d0_0,d0_1,d0_2,d0_7]");

  // Rank 0, rank 1 and rank 3 are all rejected with the same message.
  for (const char* bad_rank_shape : {"[]", "[?]", "[1,2,3]"}) {
    INFER_ERROR("Input must have even and non-zero rank", diag_part_op,
                bad_rank_shape);
  }

  // Dimension 1 (2) does not match dimension 3 (10).
  INFER_ERROR("Dimensions must be equal, but are 2 and 10", diag_part_op,
              "[1,2,?,10]");
}
265 
// MatrixDiag: the expected output is the input shape with its last dimension
// repeated; a scalar input is rejected.
TEST(ArrayOpsTest, MatrixDiag_ShapeFn) {
  ShapeInferenceTestOp matrix_diag_op("MatrixDiag");
  INFER_OK(matrix_diag_op, "?", "?");
  INFER_ERROR("Shape must be at least rank 1 but is rank 0", matrix_diag_op,
              "[]");
  INFER_OK(matrix_diag_op, "[?]", "[d0_0,d0_0]");
  INFER_OK(matrix_diag_op, "[1,?,?,4]", "[d0_0,d0_1,d0_2,d0_3,d0_3]");
}
273 
// MatrixDiagPart: requires rank >= 2; the expected output keeps the leading
// dimensions followed by the smaller of the last two.
TEST(ArrayOpsTest, MatrixDiagPart_ShapeFn) {
  ShapeInferenceTestOp diag_part_op("MatrixDiagPart");
  INFER_OK(diag_part_op, "?", "?");
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", diag_part_op,
              "[?]");
  // Square, wide and tall trailing matrices.
  INFER_OK(diag_part_op, "[?,1,2,2]", "[d0_0,d0_1,d0_2|d0_3]");
  INFER_OK(diag_part_op, "[?,1,2,3]", "[d0_0,d0_1,d0_2]");
  INFER_OK(diag_part_op, "[?,1,3,2]", "[d0_0,d0_1,d0_3]");
}
282 
// Reverse (inputs: tensor, dims): the expected output is input 0.
TEST(ArrayOpsTest, Reverse_ShapeFn) {
  ShapeInferenceTestOp reverse_op("Reverse");
  INFER_OK(reverse_op, "?;?", "in0");

  // The second input must be a vector.
  INFER_ERROR("Shape must be rank 1 but is rank 0", reverse_op, "?;[]");
  INFER_ERROR("Shape must be rank 1 but is rank 2", reverse_op, "?;[?,2]");
  // The length of the second input (4) must equal the rank of the first (3).
  INFER_ERROR("Shape must be rank 4 but is rank 3", reverse_op, "[1,2,3];[4]");
  // Rank-9 inputs are rejected.
  INFER_ERROR("reverse does not work on tensors with more than 8 dimensions",
              reverse_op, "[1,2,3,4,5,6,7,8,9];[9]");

  INFER_OK(reverse_op, "[1,2,3,?];[4]", "in0");
  INFER_OK(reverse_op, "[1,2,3,?,5,6,7,8];[8]", "in0");
}
294 
// ReverseV2 (inputs: tensor, axis): the expected output is input 0.
TEST(ArrayOpsTest, ReverseV2_ShapeFn) {
  ShapeInferenceTestOp reverse_op("ReverseV2");
  INFER_OK(reverse_op, "?;?", "in0");

  // The second input must be a vector.
  INFER_ERROR("Shape must be rank 1 but is rank 0", reverse_op, "?;[]");
  INFER_ERROR("Shape must be rank 1 but is rank 2", reverse_op, "?;[?,2]");
  // Here a length-2 second input is accepted for a rank-3 tensor.
  INFER_OK(reverse_op, "[1,2,3];[2]", "in0");
  // Rank-9 inputs are rejected.
  INFER_ERROR("reverse does not work on tensors with more than 8 dimensions",
              reverse_op, "[1,2,3,4,5,6,7,8,9];[9]");

  INFER_OK(reverse_op, "[1,2,3,?];[4]", "in0");
  INFER_OK(reverse_op, "[1,2,3,?,5,6,7,8];[8]", "in0");
}
306 
// Fill (inputs: dims, value): the output rank comes from the length of the
// dims input, and the sizes from its value when that value is available.
TEST(ArrayOpsTest, Fill_ShapeFn) {
  ShapeInferenceTestOp fill_op("Fill");
  AddNodeAttr("index_type", DT_INT32, &fill_op.node_def);
  fill_op.input_tensors.resize(2);

  // No constant dims tensor supplied yet.
  INFER_OK(fill_op, "?;?", "?");
  INFER_OK(fill_op, "[?];?", "?");
  INFER_OK(fill_op, "[4];?", "[?,?,?,?]");

  // With a constant dims tensor the output sizes are fully known.
  Tensor dims_t = test::AsTensor<int32>({1, 2, 3, 4});
  fill_op.input_tensors[0] = &dims_t;
  INFER_OK(fill_op, "[4];?", "[1,2,3,4]");
}
319 
// Gather (inputs: params, indices): the expected output is the indices shape
// followed by params' dimensions after the first.
TEST(ArrayOpsTest, Gather_ShapeFn) {
  ShapeInferenceTestOp gather_op("Gather");
  INFER_OK(gather_op, "?;?", "?");
  INFER_OK(gather_op, "[1,?,2];[3]", "[d1_0,d0_1,d0_2]");
  // Scalar params are rejected.
  INFER_ERROR("Shape must be at least rank 1 but is rank 0", gather_op,
              "[];[1,2,3]");
}
326 
TEST(ArrayOpsTest,GatherV2_ShapeFn)327 TEST(ArrayOpsTest, GatherV2_ShapeFn) {
328   ShapeInferenceTestOp op("GatherV2");
329   AddNodeAttr("batch_dims", 0, &op.node_def);
330 
331   // Tests when axis is unknown.
332   INFER_OK(op, "?;?;?", "?");
333   INFER_OK(op, "[1,2,3];[3];[]", "[?,?,?]");
334   INFER_ERROR("Shape must be at least rank 1 but is rank 0", op,
335               "[];[1,2,3];[]");
336 
337   // Non-scalar axis.
338   INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];[1,2,3];[1]");
339 
340   // Test when axis dim is known.
341   Tensor axis_dim_t;
342   op.input_tensors.resize(3);
343   op.input_tensors[2] = &axis_dim_t;
344 
345   // Out of range axis.
346   axis_dim_t = test::AsScalar(1);
347   INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
348               "[1];[1,2];[]");
349 
350   // Rank 0 indices.
351   axis_dim_t = test::AsScalar(0);
352   INFER_OK(op, "[1,2,3];[];[]", "[d0_1,d0_2]");
353   axis_dim_t = test::AsScalar(1);
354   INFER_OK(op, "[1,2,3];[];[]", "[d0_0,d0_2]");
355   axis_dim_t = test::AsScalar(2);
356   INFER_OK(op, "[1,2,3];[];[]", "[d0_0,d0_1]");
357 
358   // Rank 1 indices.
359   axis_dim_t = test::AsScalar(0);
360   INFER_OK(op, "[1,2,3];[5];[]", "[d1_0,d0_1,d0_2]");
361   axis_dim_t = test::AsScalar(1);
362   INFER_OK(op, "[1,2,3];[5];[]", "[d0_0,d1_0,d0_2]");
363   axis_dim_t = test::AsScalar(2);
364   INFER_OK(op, "[1,2,3];[5];[]", "[d0_0,d0_1,d1_0]");
365 
366   // Rank 2 indices.
367   axis_dim_t = test::AsScalar(0);
368   INFER_OK(op, "[1,2,3];[5,6];[]", "[d1_0,d1_1,d0_1,d0_2]");
369   axis_dim_t = test::AsScalar(1);
370   INFER_OK(op, "[1,2,3];[5,6];[]", "[d0_0,d1_0,d1_1,d0_2]");
371   axis_dim_t = test::AsScalar(2);
372   INFER_OK(op, "[1,2,3];[5,6];[]", "[d0_0,d0_1,d1_0,d1_1]");
373 
374   // Negative axis.
375   axis_dim_t = test::AsScalar(-3);
376   INFER_OK(op, "[1,2,3];[5,6];[]", "[d1_0,d1_1,d0_1,d0_2]");
377   axis_dim_t = test::AsScalar(-2);
378   INFER_OK(op, "[1,2,3];[5,6];[]", "[d0_0,d1_0,d1_1,d0_2]");
379   axis_dim_t = test::AsScalar(-1);
380   INFER_OK(op, "[1,2,3];[5,6];[]", "[d0_0,d0_1,d1_0,d1_1]");
381 
382   // Batch dimensions > 0
383   // Create another node since we can't overwrite the original batch dims.
384   ShapeInferenceTestOp batch_op("GatherV2");
385   AddNodeAttr("batch_dims", 1, &batch_op.node_def);
386   INFER_OK(batch_op, "[1,4800,8];[1,28400];[]", "[?,?,?]");
387 
388   ShapeInferenceTestOp batch_op_2("GatherV2");
389   AddNodeAttr("batch_dims", 2, &batch_op_2.node_def);
390   INFER_OK(batch_op_2, "[1,2,3,4,5];[1,2,3];[]", "[?,?,?,?,?]");
391 }
392 
// GatherNd (inputs: params, indices).
TEST(ArrayOpsTest, GatherNd_ShapeFn) {
  ShapeInferenceTestOp gather_nd_op("GatherNd");

  INFER_OK(gather_nd_op, "?;?", "?");
  // Last indices dimension 0: all of params' dimensions are kept.
  INFER_OK(gather_nd_op, "[1,?,3,?];[?,0]", "[d1_0,d0_0,d0_1,d0_2,d0_3]");
  // Last indices dimension equals params' rank: only indices' dims remain.
  INFER_OK(gather_nd_op, "[1,?,3,?];[?,4]", "[d1_0]");

  // The last indices dimension (4) may not exceed params' rank (3).
  INFER_ERROR("indices.shape[-1] must be <= params.rank", gather_nd_op,
              "[1,2,3];[4]");
}
404 
// Shape: the expected output is a vector whose length is the input rank, or
// unknown when the rank is unknown.
TEST(ArrayOpsTest, Shape_ShapeFn) {
  ShapeInferenceTestOp shape_op("Shape");
  INFER_OK(shape_op, "?", "[?]");
  INFER_OK(shape_op, "[?]", "[1]");
  INFER_OK(shape_op, "[?,2,3,4,5]", "[5]");
}
411 
// ShapeN with three inputs: each output is expected to be a vector whose
// length is the rank of the matching input.
TEST(ArrayOpsTest, ShapeN_ShapeFn) {
  ShapeInferenceTestOp shape_n_op("ShapeN");

  // Build a node definition with three float inputs.
  const int num_inputs = 3;
  std::vector<NodeDefBuilder::NodeOut> inputs(
      num_inputs, NodeDefBuilder::NodeOut("a", 0, DT_FLOAT));
  TF_ASSERT_OK(NodeDefBuilder("test", "ShapeN")
                   .Input(inputs)
                   .Attr("N", num_inputs)
                   .Finalize(&shape_n_op.node_def));

  INFER_OK(shape_n_op, "?;?;?", "[?];[?];[?]");
  INFER_OK(shape_n_op, "[?];[?];[?]", "[1];[1];[1]");
  INFER_OK(shape_n_op, "[?,2,3,4,5];?;[1,?,3]", "[5];[?];[3]");
}
426 
// Unique: expected outputs are an unknown-length vector and input 0; a
// non-vector input is rejected.
TEST(ArrayOpsTest, Unique_ShapeFn) {
  ShapeInferenceTestOp unique_op("Unique");
  INFER_OK(unique_op, "?", "[?];in0");
  INFER_OK(unique_op, "[5]", "[?];in0");
  INFER_ERROR("Shape must be rank 1 but is rank 5", unique_op, "[1,2,3,?,5]");
}
433 
// UniqueWithCounts: expected outputs are an unknown-length vector, input 0,
// and another unknown-length vector.
TEST(ArrayOpsTest, UniqueWithCounts_ShapeFn) {
  ShapeInferenceTestOp unique_counts_op("UniqueWithCounts");
  for (const char* input_shape : {"?", "[1,2,3,?,5]"}) {
    INFER_OK(unique_counts_op, input_shape, "[?];in0;[?]");
  }
}
439 
// InvertPermutation: the input must be a vector; the output matches it.
TEST(ArrayOpsTest, InvertPermutation_ShapeFn) {
  ShapeInferenceTestOp invert_op("InvertPermutation");
  INFER_OK(invert_op, "?", "[?]");
  INFER_OK(invert_op, "[1]", "in0");
  INFER_ERROR("Shape must be rank 1 but is rank 0", invert_op, "[]");
}
446 
// Pad and MirrorPad (inputs: input, paddings) are run through the same
// shape-inference checks.
TEST(ArrayOpsTest, PadD_ShapeFn) {
  for (const char* pad_op_name : {"Pad", "MirrorPad"}) {
    ShapeInferenceTestOp pad_op(pad_op_name);
    pad_op.input_tensors.resize(2);

    INFER_OK(pad_op, "?;?", "?");

    // paddings must be a rank-2 tensor whose second dimension is 2.
    INFER_ERROR("Shape must be rank 2 but is rank 3", pad_op, "?;[1,2,3]");
    INFER_ERROR("Dimension must be 2 but is 4", pad_op, "?;[1,4]");

    // The input rank must equal paddings.dim(0); that value is also the
    // output rank.
    INFER_ERROR("Shape must be rank 4 but is rank 3", pad_op, "[1,2,3];[4,2]");
    INFER_OK(pad_op, "[1,2,3];?", "[?,?,?]");
    INFER_OK(pad_op, "?;[3,2]", "[?,?,?]");

    // With a constant paddings tensor ((1,10),(2,20),(3,30)), the per-dimension
    // totals 11, 22 and 33 are added to the input dimensions.
    Tensor paddings_t(DT_INT64, TensorShape{3, 2});
    test::FillValues<int64>(&paddings_t, {1, 10, 2, 20, 3, 30});
    pad_op.input_tensors[1] = &paddings_t;
    INFER_OK(pad_op, "[100,200,300];[3,2]", "[111,222,333]");
    INFER_OK(pad_op, "[100,?,300];[3,2]", "[111,?,333]");
    INFER_OK(pad_op, "?;[3,2]", "[?,?,?]");
    INFER_OK(pad_op, "?;?", "[?,?,?]");
  }
}
478 
// PadV2 (inputs: input, paddings, constant_values).
TEST(ArrayOpsTest, PadV2_ShapeFn) {
  ShapeInferenceTestOp pad_op("PadV2");
  pad_op.input_tensors.resize(3);

  INFER_OK(pad_op, "?;?;?", "?");

  // paddings must be a rank-2 tensor whose second dimension is 2.
  INFER_ERROR("Shape must be rank 2 but is rank 3", pad_op, "?;[1,2,3];?");
  INFER_ERROR("Dimension must be 2 but is 4", pad_op, "?;[1,4];?");

  // The input rank must equal paddings.dim(0); that value is also the output
  // rank.
  INFER_ERROR("Shape must be rank 4 but is rank 3", pad_op,
              "[1,2,3];[4,2];[]");
  INFER_OK(pad_op, "[1,2,3];?;[]", "[?,?,?]");
  INFER_OK(pad_op, "?;[3,2];[]", "[?,?,?]");

  // With a constant paddings tensor ((1,10),(2,20),(3,30)), the per-dimension
  // totals 11, 22 and 33 are added to the input dimensions.
  Tensor paddings_t(DT_INT64, TensorShape{3, 2});
  test::FillValues<int64>(&paddings_t, {1, 10, 2, 20, 3, 30});
  pad_op.input_tensors[1] = &paddings_t;
  INFER_OK(pad_op, "[100,200,300];[3,2];[]", "[111,222,333]");
  INFER_OK(pad_op, "[100,?,300];[3,2];[]", "[111,?,333]");
  INFER_OK(pad_op, "?;[3,2];[]", "[?,?,?]");
  INFER_OK(pad_op, "?;?;[]", "[?,?,?]");
}
508 
// MirrorPadGrad (inputs: input, paddings).
TEST(ArrayOpsTest, MirrorPadGrad_ShapeFn) {
  ShapeInferenceTestOp grad_op("MirrorPadGrad");
  grad_op.input_tensors.resize(2);

  INFER_OK(grad_op, "?;?", "?");

  // paddings.dim(0) is unknown, so the output rank is unknown too.
  INFER_OK(grad_op, "?;[?,4]", "?");

  // The input rank (2) disagrees with paddings.dim(0) (3).
  INFER_ERROR("must be rank 3 but is rank 2", grad_op, "[?,?];[3,2]");

  // paddings is not shaped [rank, 2].
  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 3 and 2",
              grad_op, "[?,?,?];[3,3]");

  // The paddings value is not available but the rank is, so the expected
  // output is a rank-3 shape of unknown dimensions.
  INFER_OK(grad_op, "[?,?,?];[3,2]", "[?,?,?]");

  // With a constant paddings tensor ((1,10),(2,20),(3,30)), the per-dimension
  // totals 11, 22 and 33 are subtracted from the input dimensions.
  Tensor paddings_t(DT_INT64, TensorShape{3, 2});
  test::FillValues<int64>(&paddings_t, {1, 10, 2, 20, 3, 30});
  grad_op.input_tensors[1] = &paddings_t;

  INFER_OK(grad_op, "[111,222,333];[3,2]", "[100,200,300]");
  INFER_OK(grad_op, "[111,?,333];[3,2]", "[100,?,300]");
}
540 
// BroadcastArgs: both inputs must be vectors; the expected output is a
// vector as well.
TEST(ArrayOpsTest, BroadcastArgs_ShapeFn) {
  ShapeInferenceTestOp broadcast_op("BroadcastArgs");
  INFER_OK(broadcast_op, "?;?", "[?]");
  INFER_OK(broadcast_op, "[123];[1]", "[123]");
  INFER_OK(broadcast_op, "[1];[123]", "[123]");
  INFER_OK(broadcast_op, "[123];[121]", "[123]");

  // A scalar in either position is rejected.
  for (const char* scalar_case : {"[];?", "?;[]"}) {
    INFER_ERROR("Shape must be rank 1 but is rank 0", broadcast_op,
                scalar_case);
  }
}
552 
// BroadcastTo (inputs: input, shape), first without and then with a constant
// shape tensor.
TEST(ArrayOpsTest, BroadcastTo_ShapeFn) {
  ShapeInferenceTestOp broadcast_op("BroadcastTo");
  broadcast_op.input_tensors.resize(2);

  // Shape value not available: only the output rank can be inferred.
  INFER_OK(broadcast_op, "?;[?]", "?");
  INFER_OK(broadcast_op, "[];[1]", "[?]");
  INFER_OK(broadcast_op, "[1];[1]", "[?]");
  INFER_OK(broadcast_op, "[1];[2]", "[?,?]");
  INFER_OK(broadcast_op, "[2,2];[3]", "[?,d0_0,d0_1]");

  // Rank checks.
  INFER_ERROR("Shape must be rank 1 but is rank 2", broadcast_op, "?;[?,?]");
  INFER_ERROR("Shape must be rank 1 but is rank 0", broadcast_op, "[2];[]");
  INFER_ERROR("Shape must be at most rank 1 but is rank 2", broadcast_op,
              "[2,2];[1]");

  // Constant target shape {2, 10, 3}.
  Tensor target_shape_t(DT_INT64, TensorShape{3});
  test::FillValues<int64>(&target_shape_t, {2, 10, 3});
  broadcast_op.input_tensors[1] = &target_shape_t;
  INFER_OK(broadcast_op, "[1,?,1];[3]", "[2,10,3]");
  INFER_OK(broadcast_op, "[1,1,1];[3]", "[2,10,3]");
  INFER_OK(broadcast_op, "[10,1];[3]", "[2,d0_0,3]");
  // Input dimensions that are neither 1 nor the target size are rejected.
  INFER_ERROR("Dimensions must be equal, but are 3 and 2 for", broadcast_op,
              "[3,1,1];[3]");
  INFER_ERROR("Dimensions must be equal, but are 2 and 10 for", broadcast_op,
              "[2,2,1];[3]");
}
579 
// BroadcastGradientArgs: both inputs must be vectors; both outputs are
// expected to be vectors of unknown length.
TEST(ArrayOpsTest, BroadcastGradientArgs_ShapeFn) {
  ShapeInferenceTestOp grad_args_op("BroadcastGradientArgs");
  for (const char* vector_inputs : {"?;?", "[123];[456]"}) {
    INFER_OK(grad_args_op, vector_inputs, "[?];[?]");
  }

  // A scalar in either position is rejected.
  for (const char* scalar_case : {"[];?", "?;[]"}) {
    INFER_ERROR("Shape must be rank 1 but is rank 0", grad_args_op,
                scalar_case);
  }
}
590 
// ListDiff (inputs: x, y): both inputs must be vectors; both outputs are
// expected to be vectors of unknown length.
//
// The op under test is "ListDiff". The test previously constructed
// "BroadcastGradientArgs", so ListDiff's shape function was never exercised.
TEST(ArrayOpsTest, ListDiff_ShapeFn) {
  ShapeInferenceTestOp op("ListDiff");
  // Output is always two matching unknown vectors.
  INFER_OK(op, "?;?", "[?];[?]");
  INFER_OK(op, "[123];[456]", "[?];[?]");

  // Rank checks: a scalar in either position is rejected.
  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "[];?");
  INFER_ERROR("Shape must be rank 1 but is rank 0", op, "?;[]");
}
601 
// MatrixSetDiag (inputs: input, diagonal).
TEST(ArrayOpsTest, MatrixSetDiag_ShapeFn) {
  ShapeInferenceTestOp set_diag_op("MatrixSetDiag");

  // Rank checks on both inputs.
  INFER_ERROR("Shape must be at least rank 2 but is rank 1", set_diag_op,
              "[1];?");
  INFER_ERROR("Shape must be at least rank 1 but is rank 0", set_diag_op,
              "?;[]");
  INFER_ERROR("Shape must be at least rank 1 but is rank 0", set_diag_op,
              "[2,2];[]");
  INFER_ERROR("Shape must be rank 1 but is rank 2", set_diag_op,
              "[2,2];[2,2]");

  // The last diagonal dimension must equal the smaller matrix dimension.
  INFER_ERROR("Dimensions must be equal, but are 2 and 3", set_diag_op,
              "[2,3];[3]");

  // Cases where the expected output is exactly input 0.
  for (const char* shapes : {"?;?", "[1,2,2];[1,2]", "[1,2,3];?", "[1,3,2];?",
                             "[1,?,2];[?,?]", "[1,?,?];[?,2]"}) {
    INFER_OK(set_diag_op, shapes, "in0");
  }

  // When the input is only partly specified, the batch dimension is expected
  // to come from the diagonal.
  INFER_OK(set_diag_op, "?;[1,2]", "[d1_0,?,?]");
  INFER_OK(set_diag_op, "[?,?,3];[1,2]", "[d1_0,d0_1,d0_2]");
  INFER_OK(set_diag_op, "[?,3,?];[1,2]", "[d1_0,d0_1,d0_2]");
  INFER_OK(set_diag_op, "[?,3,2];[1,2]", "[d1_0,d0_1,d0_2]");
}
630 
TEST(ArrayOpsTest,ExpandDims_ShapeFn)631 TEST(ArrayOpsTest, ExpandDims_ShapeFn) {
632   ShapeInferenceTestOp op("ExpandDims");
633   op.input_tensors.resize(2);
634 
635   // With unknown dim tensor value, output is unknown.
636   INFER_OK(op, "?;?", "?");
637   Tensor dim_t;
638   op.input_tensors[1] = &dim_t;
639 
640   // Expand at front of tensor.
641   for (int32_t idx : {0, -4}) {
642     dim_t = test::AsScalar<int32>(idx);
643     INFER_OK(op, "?;?", "?");
644     INFER_OK(op, "[5,?,7];?", "[1,d0_0,d0_1,d0_2]");
645   }
646 
647   // Expand at middle of tensor.
648   for (int32_t idx : {1, -3}) {
649     dim_t = test::AsScalar<int32>(idx);
650     INFER_OK(op, "?;?", "?");
651     INFER_OK(op, "[5,?,7];?", "[d0_0,1,d0_1,d0_2]");
652 
653     // Repeat with int64.
654     dim_t = test::AsScalar<int64>(idx);
655     INFER_OK(op, "?;?", "?");
656     INFER_OK(op, "[5,?,7];?", "[d0_0,1,d0_1,d0_2]");
657   }
658   for (int32_t idx : {2, -2}) {
659     dim_t = test::AsScalar<int32>(idx);
660     INFER_OK(op, "?;?", "?");
661     INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,1,d0_2]");
662 
663     // Repeat with int64.
664     dim_t = test::AsScalar<int64>(idx);
665     INFER_OK(op, "?;?", "?");
666     INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,1,d0_2]");
667   }
668 
669   for (int32_t idx : {3, -1}) {
670     // Expand at the end.
671     dim_t = test::AsScalar<int32>(idx);
672     INFER_OK(op, "?;?", "?");
673     INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,d0_2,1]");
674 
675     // Repeat with int64.
676     dim_t = test::AsScalar<int64>(idx);
677     INFER_OK(op, "?;?", "?");
678     INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,d0_2,1]");
679   }
680   for (int32_t idx : {4, -5}) {
681     // Invalid idx.
682     dim_t = test::AsScalar<int32>(idx);
683     INFER_ERROR("not in the interval [-4, 3]", op, "[5,?,7];?");
684     dim_t = test::AsScalar<int64>(idx);
685     INFER_ERROR("not in the interval [-4, 3]", op, "[5,?,7];?");
686   }
687 
688   // Expand using an input vector tensor.
689   std::vector<int32> dims;
690   dims.push_back(0);
691   dim_t = test::AsTensor<int32>(dims);
692   INFER_OK(op, "?;?", "?");
693   INFER_OK(op, "[5,?,7];?", "[1,d0_0,d0_1,d0_2]");
694 
695   // Expand using too many input elements.
696   dims.push_back(1);
697   dim_t = test::AsTensor<int32>(dims);
698   INFER_ERROR("'dim' input must be a tensor with a single", op, "?;?");
699   INFER_ERROR("'dim' input must be a tensor with a single", op, "[5,6,7];?");
700 
701   // Examples from ExpandDims doc.
702   dim_t = test::AsScalar<int32>(0);
703   INFER_OK(op, "[2];[]", "[1,d0_0]");
704   dim_t = test::AsScalar<int32>(1);
705   INFER_OK(op, "[2];[]", "[d0_0,1]");
706   dim_t = test::AsScalar<int32>(-1);
707   INFER_OK(op, "[2];[]", "[d0_0,1]");
708 }
709 
// ImmutableConst: the output shape comes from the "shape" attr, which must
// actually hold a shape.
TEST(ArrayOpsTest, ImmutableConst_ShapeFn) {
  ShapeInferenceTestOp immutable_op("ImmutableConst");

  // Rank-3 shape attr.
  TF_ASSERT_OK(NodeDefBuilder("test", "ImmutableConst")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({1, 2, 3}))
                   .Attr("memory_region_name", "test_region")
                   .Finalize(&immutable_op.node_def));
  INFER_OK(immutable_op, "", "[1,2,3]");

  // Scalar shape attr.
  TF_ASSERT_OK(NodeDefBuilder("test", "ImmutableConst")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", TensorShape({}))
                   .Attr("memory_region_name", "test_region")
                   .Finalize(&immutable_op.node_def));
  INFER_OK(immutable_op, "", "[]");

  // A string stored under "shape" is reported as an attr type error.
  TF_ASSERT_OK(NodeDefBuilder("test", "ImmutableConst")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", "invalid")
                   .Attr("memory_region_name", "test_region")
                   .Finalize(&immutable_op.node_def));
  INFER_ERROR("AttrValue had value with type 'string' when 'shape' expected",
              immutable_op, "");
}
735 
TEST(ArrayOpsTest,Concat_ShapeFn)736 TEST(ArrayOpsTest, Concat_ShapeFn) {
737   ShapeInferenceTestOp op("Concat");
738   auto set_n = [&op](int n) {
739     std::vector<NodeDefBuilder::NodeOut> src_list;
740     src_list.reserve(n);
741     for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT);
742     TF_ASSERT_OK(NodeDefBuilder("test", "Concat")
743                      .Input({"concat_dim", 0, DT_INT32})
744                      .Input(src_list)
745                      .Attr("n", n)
746                      .Finalize(&op.node_def));
747   };
748 
749   // Confirm dimension[0] of the input (the concat_dim) is a scalar.
750   set_n(2);
751   INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];?;?");
752 
753   // Test with the input concat_dim tensor not known. This takes the known rank
754   // of the inputs and makes a tensor of that many unknown dims.
755   set_n(7);
756   INFER_OK(op, "?;?;?;?;[1,2,3];?;[3,2,1];?", "[?,?,?]");
757   set_n(4);
758   INFER_OK(op, "?;?;?;[1,2,3,4];[4,3,2,1]", "[?,?,?,?]");
759   INFER_OK(op, "?;?;?;?;?", "?");  // output rank unknown
760   INFER_ERROR("Can't concatenate scalars (use tf.stack instead)", op,
761               "?;?;?;[];[]");
762   INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;?;?;[1,2];[1,2,3]");
763 
764   // Test when the concat_dim tensor is known. The concatenated dimension is
765   // summed across all input tensors, and other dimensions are merged.
766   Tensor concat_dim_t;
767   op.input_tensors.push_back(&concat_dim_t);
768   set_n(2);
769 
770   // Sum dim 0, merge the other two dims.
771   for (int concat_dim : {0, -3}) {
772     concat_dim_t = test::AsScalar(concat_dim);
773     INFER_OK(op, "[];[100,2,?];[10,?,3]", "[110,d1_1,d2_2]");
774     INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
775                 "[];[100,2,5];[10,?,3]");
776     // concat_dim can't be summed, as one value is unknown.
777     INFER_OK(op, "[];[100,2,?];[?,?,3]", "[?,d1_1,d2_2]");
778     INFER_OK(op, "[];[?,2,?];[10,?,3]", "[?,d1_1,d2_2]");
779   }
780 
781   // Test with a higher concat_dim.
782   for (bool use_negative : {false, true}) {
783     concat_dim_t = test::AsScalar(use_negative ? -2 : 1);
784     INFER_OK(op, "[];[1,100,?];[?,10,3]", "[d1_0,110,d2_2]");
785     concat_dim_t = test::AsScalar(use_negative ? -1 : 1);
786     INFER_OK(op, "[];[1,100];[?,10]", "[d1_0,110]");
787     INFER_OK(op, "[];[?,100];[1,10]", "[d2_0,110]");
788 
789     // concat_dim is out of bounds.
790     concat_dim_t = test::AsScalar(use_negative ? -2 : 1);
791     INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
792                 "[];[100];[10,?]");
793     INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
794                 "[];[100,5];[10]");
795   }
796 
797   // concat_dim is too low.
798   concat_dim_t = test::AsScalar(-2);
799   INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
800               "[];[100];[10,?]");
801   INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
802               "[];[100,5];[10]");
803 
804   // Repeat successful case with several unknown inputs.
805   set_n(5);
806   concat_dim_t = test::AsScalar(1);
807   INFER_OK(op, "[];?;[1,100,?];[?,?,?];[?,10,3];?", "[d2_0,?,d4_2]");
808 }
809 
TEST(ArrayOpsTest,ConcatV2_ShapeFn)810 TEST(ArrayOpsTest, ConcatV2_ShapeFn) {
811   ShapeInferenceTestOp op("ConcatV2");
812   auto set_n = [&op](int n) {
813     std::vector<NodeDefBuilder::NodeOut> src_list;
814     src_list.reserve(n);
815     for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT);
816     TF_ASSERT_OK(NodeDefBuilder("test", "ConcatV2")
817                      .Input(src_list)
818                      .Input({"axis", 0, DT_INT32})
819                      .Attr("n", n)
820                      .Finalize(&op.node_def));
821   };
822 
823   // Confirm dimension[0] of the input (the concat_dim) is a scalar.
824   set_n(2);
825   INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;?;[1]");
826 
827   // Test with the input concat_dim tensor not known. This takes the known rank
828   // of the inputs and makes a tensor of that many unknown dims.
829   set_n(7);
830   INFER_OK(op, "?;?;?;?;[1,2,3];?;[3,2,1];?", "[?,?,?]");
831   set_n(4);
832   INFER_OK(op, "?;?;[1,2,3,4];[4,3,2,1];?", "[?,?,?,?]");
833   INFER_OK(op, "?;?;?;?;?", "?");  // output rank unknown
834   INFER_ERROR("Can't concatenate scalars (use tf.stack instead)", op,
835               "?;?;[];[];?");
836   INFER_ERROR("Shape must be rank 2 but is rank 3", op, "?;?;[1,2];[1,2,3];?");
837 
838   // Test when the concat_dim tensor is known. The concatenated dimension is
839   // summed across all input tensors, and other dimensions are merged.
840   Tensor concat_dim_t;
841   op.input_tensors.resize(3);
842   op.input_tensors[2] = &concat_dim_t;
843 
844   set_n(2);
845 
846   // Invalid concat dim value.
847   // concat_dim_t = test::AsScalar(-1);
848   // INFER_ERROR("Expected concat_dim >= 0, but got -1", op, "?;?;?");
849 
850   // Sum dim 0, merge the other two dims.
851   concat_dim_t = test::AsScalar(0);
852   INFER_OK(op, "[100,2,?];[10,?,3];[]", "[110,d0_1,d1_2]");
853   INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
854               "[100,2,5];[10,?,3];[]");
855   // concat_dim can't be summed, as one value is unknown.
856   INFER_OK(op, "[100,2,?];[?,?,3];[]", "[?,d0_1,d1_2]");
857   INFER_OK(op, "[?,2,?];[10,?,3];[]", "[?,d0_1,d1_2]");
858 
859   // Test with a higher concat_dim.
860   concat_dim_t = test::AsScalar(1);
861   INFER_OK(op, "[1,100,?];[?,10,3];[]", "[d0_0,110,d1_2]");
862   INFER_OK(op, "[1,100];[?,10];[]", "[d0_0,110]");
863   INFER_OK(op, "[?,100];[1,10];[]", "[d1_0,110]");
864   // concat_dim is too high.
865   INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
866               "[100];[10,?];[]");
867   INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
868               "[100,5];[10];[]");
869   // concat_dim is too low.
870   concat_dim_t = test::AsScalar(-2);
871   INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
872               "[100];[10,?];[]");
873   INFER_ERROR("Shape must be at least rank 2 but is rank 1", op,
874               "[100,5];[10];[]");
875 
876   // Repeat successful case with several unknown inputs.
877   op.input_tensors.resize(6);
878   op.input_tensors[3] = nullptr;
879   op.input_tensors[5] = &concat_dim_t;
880   concat_dim_t = test::AsScalar(1);
881 
882   set_n(5);
883   INFER_OK(op, "?;[1,100,?];[?,?,?];[?,10,3];?;[]", "[d1_0,?,d3_2]");
884 }
885 
// Shape-function test for ConcatOffset: with input 0 as concat_dim, each of
// the outputs is expected to take the shape of the matching input 1..4.
TEST(ArrayOpsTest, ConcatOffset_ShapeFn) {
  ShapeInferenceTestOp op("ConcatOffset");

  // Build the list of DT_INT32 shape inputs that follow concat_dim.
  const int num_shapes = 4;
  std::vector<NodeDefBuilder::NodeOut> shape_inputs;
  shape_inputs.reserve(num_shapes);
  for (int remaining = num_shapes; remaining > 0; --remaining) {
    shape_inputs.emplace_back("a", 0, DT_INT32);
  }

  TF_ASSERT_OK(NodeDefBuilder("test", "ConcatOffset")
                   .Input({"concat_dim", 0, DT_INT32})
                   .Input(shape_inputs)
                   .Attr("n", num_shapes)
                   .Finalize(&op.node_def));

  INFER_OK(op, "?;?;?;?;?", "in1;in2;in3;in4");
}
900 
// Shape-function test for Reshape. Input 0 is the tensor, input 1 the
// requested shape; the constant value of input 1 is supplied via
// op.input_tensors[1] once it is set below.
TEST(ArrayOpsTest, Reshape_ShapeFn) {
  ShapeInferenceTestOp op("Reshape");
  op.input_tensors.resize(2);

  // Without a constant value for the shape input the output stays unknown,
  // whatever is known about the two input shapes.
  for (const char* shapes : {"?;?", "[?];?", "?;[?]", "[?];[?]", "[4];[?]"}) {
    INFER_OK(op, shapes, "?");
  }

  // A fully specified target shape is returned as-is.
  Tensor target = test::AsTensor<int32>({1, 2, 3});
  op.input_tensors[1] = &target;
  INFER_OK(op, "?;[3]", "[1,2,3]");
  INFER_OK(op, "[?];[3]", "[1,2,3]");
  INFER_OK(op, "[6];[3]", "[1,2,3]");
  // ...but only if the element counts agree.
  INFER_ERROR(
      "Cannot reshape a tensor with 12 elements to shape [1,2,3] (6 elements)",
      op, "[3,4];[3]");

  // A -1 entry marks a dimension to be inferred.
  // Flatten:
  target = test::AsTensor<int32>({-1});
  INFER_OK(op, "?;[1]", "[?]");
  INFER_OK(op, "[?];[1]", "[d0_0]");
  INFER_OK(op, "[2,2];[1]", "[4]");
  // The trailing dimension is inferred (12 elements / 2 = 6):
  target = test::AsTensor<int32>({2, -1});
  INFER_OK(op, "[3,4];[2]", "[2,6]");
  // The element count has to divide evenly by the known dimensions.
  INFER_ERROR("Dimension size must be evenly divisible by 2 but is 7", op,
              "[7];[2]");
  // With more than one -1 the missing dimensions stay unknown.
  target = test::AsTensor<int32>({-1, -1, 2});
  INFER_OK(op, "[8];[3]", "[?,?,2]");
  INFER_OK(op, "?;[3]", "[?,?,2]");

  // An unknown input dimension is carried over symbolically.
  target = test::AsTensor<int32>({-1, 2, 3});
  INFER_OK(op, "[?,2,3];[3]", "[d0_0,2,3]");

  // An empty target shape reshapes to a scalar.
  target = test::AsTensor<int32>({});
  INFER_OK(op, "[1];[0]", "[]");
  INFER_ERROR(
      "Cannot reshape a tensor with 2 elements to shape [] (1 elements)", op,
      "[1,2];[0]");

  // Inputs that contain no elements.
  target = test::AsTensor<int32>({-1});
  INFER_OK(op, "[0];[1]", "[0]");
  target = test::AsTensor<int32>({-1, 6});
  INFER_OK(op, "[0,2];[1]", "[0,6]");
  target = test::AsTensor<int32>({0, -1});
  INFER_OK(op, "[0,2];[1]", "[0,?]");
}
960 
// Shape-function test for QuantizedReshape. Besides the reshaped tensor the
// op is expected to produce two scalar outputs, and its input_min/input_max
// inputs (inputs 2 and 3) must be scalars.
TEST(ArrayOpsTest, QuantizedReshape_ShapeFn) {
  ShapeInferenceTestOp op("QuantizedReshape");
  op.input_tensors.resize(2);

  // Only a sample of the Reshape_ShapeFn cases is repeated here, since
  // QuantizedReshape uses the same code for the reshape part of the operation.
  for (const char* shapes :
       {"?;?;?;?", "[?];?;?;?", "[?];[?];?;?", "[4];[?];?;?"}) {
    INFER_OK(op, shapes, "?;[];[]");
  }

  Tensor target = test::AsTensor<int32>({1, 2, 3});
  op.input_tensors[1] = &target;
  INFER_OK(op, "[?];[3];?;?", "[1,2,3];[];[]");
  INFER_OK(op, "[6];[3];?;?", "[1,2,3];[];[]");
  INFER_ERROR(
      "Cannot reshape a tensor with 12 elements to shape [1,2,3] (6 elements)",
      op, "[3,4];[3];?;?");

  // Rank checks on input_min and input_max.
  INFER_ERROR("must be rank 0", op, "?;?;[1];?");
  INFER_ERROR("must be rank 0", op, "?;?;?;[1]");
}
983 
// Shape-function test for Placeholder: the output shape is expected to mirror
// the "shape" attr. Each case uses its own op so no node_def state is shared.
TEST(ArrayOpsTest, Placeholder_ShapeFn) {
  {
    // Fully defined rank-2 shape.
    ShapeInferenceTestOp op("Placeholder");
    TensorShape matrix_shape({1, 2});
    TF_ASSERT_OK(NodeDefBuilder("test", "Placeholder")
                     .Attr("shape", matrix_shape)
                     .Attr("dtype", DT_FLOAT)
                     .Finalize(&op.node_def));
    INFER_OK(op, "", "[1,2]");
  }

  {
    // A scalar (rank-0) shape is accepted as well.
    ShapeInferenceTestOp op("Placeholder");
    TensorShape scalar_shape({});
    TF_ASSERT_OK(NodeDefBuilder("test", "Placeholder")
                     .Attr("shape", scalar_shape)
                     .Attr("dtype", DT_FLOAT)
                     .Finalize(&op.node_def));
    INFER_OK(op, "", "[]");
  }

  {
    // Partially known shape: a -1 entry shows up as an unknown dimension.
    ShapeInferenceTestOp op("Placeholder");
    const int64 partial_dims[2] = {1, -1};
    PartialTensorShape partial_shape;
    TF_ASSERT_OK(
        PartialTensorShape::MakePartialShape(partial_dims, 2, &partial_shape));
    TF_ASSERT_OK(NodeDefBuilder("test", "Placeholder")
                     .Attr("shape", partial_shape)
                     .Attr("dtype", DT_FLOAT)
                     .Finalize(&op.node_def));
    INFER_OK(op, "", "[1,?]");
  }

  {
    // Default-constructed PartialTensorShape: completely unknown output.
    ShapeInferenceTestOp op("Placeholder");
    PartialTensorShape unknown_shape;
    TF_ASSERT_OK(NodeDefBuilder("test", "Placeholder")
                     .Attr("shape", unknown_shape)
                     .Attr("dtype", DT_FLOAT)
                     .Finalize(&op.node_def));
    INFER_OK(op, "", "?");
  }
}
1031 
// Shape-function test for Transpose. Input 0 is the tensor and input 1 the
// permutation, whose constant value is supplied through op.input_tensors[1].
TEST(ArrayOpsTest, Transpose_ShapeFn) {
  ShapeInferenceTestOp op("Transpose");
  op.input_tensors.resize(2);

  // Permutation value unavailable: at most the output rank can be inferred.
  INFER_OK(op, "?;?", "?");
  INFER_OK(op, "?;[?]", "?");
  INFER_OK(op, "?;[2]", "[?,?]");
  INFER_OK(op, "[?];?", "[?]");
  INFER_OK(op, "[?,?];[2]", "[?,?]");
  INFER_ERROR("Dimension must be 3 but is 2", op, "[1,2,3];[2]");

  // Permutation value available.
  Tensor perm_t = test::AsTensor<int32>({0});
  op.input_tensors[1] = &perm_t;
  INFER_OK(op, "[?];[?]", "[d0_0]");
  perm_t = test::AsTensor<int32>({1, 0});
  INFER_OK(op, "?;[2]", "[?,?]");
  INFER_OK(op, "[?,?];[2]", "[d0_1,d0_0]");
  INFER_OK(op, "[1,?];[2]", "[d0_1,d0_0]");
  INFER_OK(op, "?;[0]", "in0");

  // Rejected permutations.
  perm_t = test::AsTensor<int32>({1, 2});
  INFER_ERROR("perm dim 2 is out of range of input rank 2", op, "[1,2];[2]");
  perm_t = test::AsTensor<int32>({0});
  INFER_ERROR("Dimension must be 2 but is 1", op, "[1,2];[1]");

  // A rank-5 permutation, with and without an unknown dimension.
  perm_t = test::AsTensor<int32>({1, 0, 3, 4, 2});
  for (const char* shapes : {"[0,1,2,3,4];[5]", "[0,?,2,3,4];[5]"}) {
    INFER_OK(op, shapes, "[d0_1,d0_0,d0_3,d0_4,d0_2]");
  }
}
1063 
// Shape-function test for Bitcast across equal, growing and shrinking element
// sizes, plus types for which a zero type size is reported.
TEST(ArrayOpsTest, Bitcast_ShapeFn) {
  ShapeInferenceTestOp op("Bitcast");
  // Rebuilds op.node_def for a cast from `from_type` to `to_type`.
  auto set_types = [&op](DataType from_type, DataType to_type) {
    TF_ASSERT_OK(NodeDefBuilder("test", "Bitcast")
                     .Input("input", 0, from_type)
                     .Attr("type", to_type)
                     .Finalize(&op.node_def));
  };

  set_types(DT_FLOAT, DT_INT32);
  // Unknown input shape gives an unknown output shape.
  INFER_OK(op, "?", "?");

  // Same element size: the shape passes through unchanged.
  INFER_OK(op, "[1,2]", "in0");

  // Smaller to larger element type: the last dimension is consumed.
  set_types(DT_INT32, DT_INT64);
  INFER_OK(op, "[1,2]", "[d0_0]");  // last dimension matches divisor.
  // TODO(vrv): Seems like a bug, or at least, too lenient.
  INFER_OK(op, "[1,?]", "[d0_0]");
  // 4 is divisible by 2, but the shape function signature requires
  // that the last dimension matches the last value exactly.
  INFER_ERROR("does not match", op, "[1,4]");
  INFER_ERROR("does not match", op, "[1,3]");

  // Larger to smaller element type: a trailing dimension holding the size
  // ratio is appended.
  set_types(DT_INT64, DT_INT32);
  INFER_OK(op, "[4,5]", "[d0_0,d0_1,2]");
  set_types(DT_COMPLEX128, DT_INT32);
  INFER_OK(op, "[4,5]", "[d0_0,d0_1,4]");
  set_types(DT_COMPLEX128, DT_HALF);
  INFER_OK(op, "[4,5]", "[d0_0,d0_1,8]");
  set_types(DT_COMPLEX128, DT_INT8);
  INFER_OK(op, "[4,5]", "[d0_0,d0_1,16]");

  // DT_STRING on either side is rejected with a zero-type-size error.
  set_types(DT_STRING, DT_INT32);
  INFER_ERROR("one of the type sizes is zero", op, "[1,2,3]");
  set_types(DT_INT32, DT_STRING);
  INFER_ERROR("one of the type sizes is zero", op, "[1,2,3]");
}
1106 
// Shape-function test for Squeeze with an empty, explicit, negative and
// out-of-range "squeeze_dims" attr.
TEST(ArrayOpsTest, Squeeze_ShapeFn) {
  ShapeInferenceTestOp op("Squeeze");

  // Rebuilds op.node_def with the given squeeze_dims attr.
  auto set_squeeze_dims = [&op](const std::vector<int32>& dims) {
    TF_ASSERT_OK(NodeDefBuilder("test", "Squeeze")
                     .Input("input", 0, DT_FLOAT)
                     .Attr("squeeze_dims", dims)
                     .Finalize(&op.node_def));
  };

  // Empty squeeze_dims (the default).
  set_squeeze_dims({});

  // Unknown input shape gives an unknown output shape.
  INFER_OK(op, "?", "?");

  INFER_OK(op, "[1,4,1,5,1]", "[d0_1,d0_3]");

  // With unknown dimensions present the result is unknown altogether.
  INFER_OK(op, "[1,?,1,?,1]", "?");

  // One explicitly named dimension.
  set_squeeze_dims({1});
  INFER_OK(op, "[4,1,5]", "[d0_0,d0_2]");
  // An explicitly squeezed unknown dim is assumed to be 1 at runtime.
  INFER_OK(op, "[4,?,5]", "[d0_0,d0_2]");

  // A dimension that is known not to be 1 cannot be squeezed.
  INFER_ERROR("Can not squeeze dim[1]", op, "[4,6,5]");

  // Several dimensions at once, using positive and negative indices.
  set_squeeze_dims({1, 2});
  INFER_OK(op, "[4,1,1,5]", "[d0_0,d0_3]");
  set_squeeze_dims({1, -2});
  INFER_OK(op, "[4,1,1,5]", "[d0_0,d0_3]");

  // A single negative index.
  set_squeeze_dims({-2});
  INFER_OK(op, "[4,1,5]", "[d0_0,d0_2]");

  // Indices outside [-rank, rank) are rejected.
  set_squeeze_dims({-4});
  INFER_ERROR("not in [-3,3)", op, "[1,2,3]");
  set_squeeze_dims({3});
  INFER_ERROR("not in [-3,3)", op, "[1,2,3]");
}
1153 
// Shape-function test for ReverseSequence (inputs: input; seq_lengths) with
// varying seq_dim / batch_dim attrs.
TEST(ArrayOpsTest, ReverseSequence_ShapeFn) {
  ShapeInferenceTestOp op("ReverseSequence");
  // Rebuilds op.node_def with the given seq_dim and batch_dim attrs.
  auto set_dims = [&op](const int32_t seq_dim, const int32_t batch_dim) {
    TF_ASSERT_OK(NodeDefBuilder("test", "ReverseSequence")
                     .Input("input", 0, DT_FLOAT)
                     .Input("seq_lengths", 1, DT_INT64)
                     .Attr("seq_dim", seq_dim)
                     .Attr("batch_dim", batch_dim)
                     .Finalize(&op.node_def));
  };

  set_dims(1, 2);
  // Unknown input shape gives an unknown output shape.
  INFER_OK(op, "?;[10]", "?");

  // seq_lengths has to be a vector.
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[10,10]");

  // Both attrs must index a dimension of the input.
  set_dims(1, 4);
  INFER_ERROR("batch_dim must be < input rank", op, "[1,2,3];[3]");
  set_dims(4, 1);
  INFER_ERROR("seq_dim must be < input rank", op, "[1,2,3];[3]");

  set_dims(1, 2);
  INFER_OK(op, "[1,2,3];[3]", "[d0_0,d0_1,d0_2]");
  // An unknown batch dimension is resolved by merging with seq_lengths.
  INFER_OK(op, "[1,2,?];[3]", "[d0_0,d0_1,d1_0]");
  INFER_OK(op, "[1,2,3];[?]", "[d0_0,d0_1,d0_2]");
}
1185 
// Shape-function test for Split with num_split = 2. Input 0 is split_dim and
// input 1 the value; the constant split_dim is supplied through
// op.input_tensors[0].
TEST(ArrayOpsTest, Split_ShapeFn) {
  ShapeInferenceTestOp op("Split");
  op.input_tensors.resize(2);

  TF_ASSERT_OK(NodeDefBuilder("test", "Split")
                   .Input("split_dim", 0, DT_INT32)
                   .Input("value", 1, DT_FLOAT)
                   .Attr("num_split", 2)
                   .Finalize(&op.node_def));

  // Neither the split_dim value nor the value's shape is known.
  INFER_OK(op, "?;?", "?;?");
  // A known value rank fixes the rank of every output, even when the value's
  // dimensions are known but split_dim is not.
  for (const char* shapes : {"?;[?,?]", "?;[1,4]"}) {
    INFER_OK(op, shapes, "[?,?];[?,?]");
  }

  // split_dim value available.
  Tensor split_dim_t = test::AsTensor<int32>({1, 2});
  op.input_tensors[0] = &split_dim_t;
  INFER_ERROR("Input must be scalar but has rank 1", op, "[?];[?,?]");
  split_dim_t = test::AsScalar<int32>(1);
  INFER_OK(op, "?;?", "?;?");
  INFER_OK(op, "?;[?,?]", "[d1_0,?];[d1_0,?]");
  INFER_OK(op, "?;[1,4]", "[d1_0,2];[d1_0,2]");
  INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
  INFER_ERROR("Dimension size must be evenly divisible by 2 but is 5", op,
              "?;[1,5]");

  // split_dim beyond the last dimension.
  split_dim_t = test::AsScalar<int32>(3);
  INFER_ERROR(
      "Dimension size, given by scalar input 3 must be in range [-3, 3)", op,
      "?;[1,4,8]");

  // Negative split_dim values.
  split_dim_t = test::AsScalar<int32>(-1);
  INFER_OK(op, "?;?", "?;?");
  INFER_OK(op, "?;[?,?]", "[d1_0,?];[d1_0,?]");
  INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
  INFER_OK(op, "?;[1,4]", "[d1_0,2];[d1_0,2]");
  INFER_OK(op, "?;[1,4,8]", "[d1_0,d1_1,4];[d1_0,d1_1,4]");
  split_dim_t = test::AsScalar<int32>(-2);
  INFER_OK(op, "?;[1,4,8]", "[d1_0,2,d1_2];[d1_0,2,d1_2]");
  split_dim_t = test::AsScalar<int32>(-4);
  INFER_ERROR(
      "Dimension size, given by scalar input -4 must be in range [-3, 3)", op,
      "?;[1,4,8]");
}
1235 
// Shape-function test for Tile (inputs: input; multiples). The constant
// multiples are supplied through op.input_tensors[1].
TEST(ArrayOpsTest, Tile_ShapeFn) {
  ShapeInferenceTestOp op("Tile");
  op.input_tensors.resize(2);

  TF_ASSERT_OK(NodeDefBuilder("test", "Tile")
                   .Input("input", 0, DT_FLOAT)
                   .Input("multiples", 1, DT_INT32)
                   .Finalize(&op.node_def));

  // Nothing known about either input: unknown output.
  INFER_OK(op, "?;?", "?");

  // Known input shape, unknown multiples: only the output rank is known.
  INFER_OK(op, "[2,3,1,4];?", "[?,?,?,?]");

  // 'multiples' has to be a vector.
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[2,3,1,4];[4,1]");

  // Unknown input, but the length of multiples still fixes the output rank.
  INFER_OK(op, "?;[4]", "[?,?,?,?]");

  // Tiling a 4D input with known multiples, given first as int32 and then as
  // int64 values.
  Tensor multiples_t = test::AsTensor<int32>({2, 3, 4, 5});
  op.input_tensors[1] = &multiples_t;
  INFER_OK(op, "[2,3,1,4];[4]", "[4,9,4,20]");
  multiples_t = test::AsTensor<int64>({2, 3, 4, 5});
  INFER_OK(op, "[2,3,1,4];[4]", "[4,9,4,20]");
}
1266 
// Shape-function test for EditDistance. Inputs 2 and 5 carry the hypothesis
// and truth shape values respectively.
TEST(ArrayOpsTest, EditDistance_ShapeFn) {
  ShapeInferenceTestOp op("EditDistance");
  op.input_tensors.resize(6);

  // No constant shape tensors: the output shape is unknown.
  INFER_OK(op, "[?,?];[?];[4];[?,?];[?];[4]", "?");

  // Both shape tensors known. The expected output takes the larger of each
  // leading pair (2 vs 20, 30 vs 3, 4 vs 40) and has no entry for the last.
  Tensor hypothesis_shape_t = test::AsTensor<int64>({2, 30, 4, 50});
  Tensor truth_shape_t = test::AsTensor<int64>({20, 3, 40, 5});
  op.input_tensors[2] = &hypothesis_shape_t;
  op.input_tensors[5] = &truth_shape_t;
  INFER_OK(op, "[?,?];[?];[4];[?,?];[?];[4]", "[20,30,40]");

  // The two shape tensors must hold the same number of elements. Slot 2
  // already points at hypothesis_shape_t, so reassigning its value suffices.
  hypothesis_shape_t = test::AsTensor<int64>({2});
  INFER_ERROR("Num elements of hypothesis_shape does not match truth_shape", op,
              "[?,?];[?];[1];[?,?];[?];[4]");
}
1286 
// Shape-function test for OneHot (inputs: indices; depth; on_value;
// off_value). The constant depth is supplied through op.input_tensors[1].
TEST(ArrayOpsTest, OneHot_ShapeFn) {
  ShapeInferenceTestOp op("OneHot");
  op.input_tensors.resize(4);
  // Rebuilds op.node_def with the given "axis" attr.
  auto build_with_axis = [&op](int axis) {
    TF_ASSERT_OK(NodeDefBuilder("test", "OneHot")
                     .Input("indices", 0, DT_FLOAT)
                     .Input("depth", 1, DT_INT32)
                     .Input("on_value", 2, DT_FLOAT)
                     .Input("off_value", 3, DT_FLOAT)
                     .Attr("axis", axis)
                     .Finalize(&op.node_def));
  };

  // An axis below -1 is rejected.
  build_with_axis(-2);
  INFER_ERROR("axis must be >= -1", op, "?;?;?;?");
  build_with_axis(1);

  // Unknown indices shape gives an unknown output shape.
  INFER_OK(op, "?;[];?;?", "?");

  // A non-scalar depth value is rejected.
  Tensor depth_t = test::AsTensor<int32>({1, 2});
  op.input_tensors[1] = &depth_t;
  INFER_ERROR("Input must be scalar but has rank 1", op, "?;[2];?;?");

  // Everything known: the depth dimension is inserted at `axis`, or appended
  // at the end for axis == -1.
  depth_t = test::AsScalar<int32>(2);
  INFER_OK(op, "[1,3,4];[];?;?", "[d0_0,2,d0_1,d0_2]");
  build_with_axis(-1);
  INFER_OK(op, "[1,3,4];[];?;?", "[d0_0,d0_1,d0_2,2]");
}
1319 
// Shape-function test for ExtractImagePatches, concentrating on how ksizes
// and rates combine.
TEST(ArrayOpsTest, ExtractImagePatchesShapeTest) {
  ShapeInferenceTestOp op("ExtractImagePatches");
  // Rebuilds op.node_def with the given attrs.
  auto configure = [&op](const std::vector<int32>& ksizes_attr,
                         const std::vector<int32>& strides_attr,
                         const std::vector<int32>& rates_attr,
                         const string& padding_attr) {
    TF_ASSERT_OK(NodeDefBuilder("test", "ExtractImagePatches")
                     .Input("input", 0, DT_FLOAT)
                     .Attr("ksizes", ksizes_attr)
                     .Attr("strides", strides_attr)
                     .Attr("rates", rates_attr)
                     .Attr("padding", padding_attr)
                     .Finalize(&op.node_def));
  };

  // Just tests that the ksize calculation with rates works.  Most of
  // the other code is boilerplate that is tested by a variety of
  // other ops.
  //
  // With 2x2 ksizes and a rate of 2 for rows and cols, the effective
  // ksize_rows and cols become 2 + (2 - 1) = 3.  A 7x7 input with a 3x3
  // filter and 1x1 stride gives a 5x5 output.
  configure({1, 2, 2, 1}, {1, 1, 1, 1}, {1, 2, 2, 1}, "VALID");
  INFER_OK(op, "[1,7,7,2]", "[d0_0,5,5,8]");
  // With 1x1 ksizes the output depth equals the input depth and the 7x7
  // spatial size is unchanged.
  configure({1, 1, 1, 1}, {1, 1, 1, 1}, {1, 2, 2, 1}, "VALID");
  INFER_OK(op, "[1,7,7,2]", "[d0_0,7,7,d0_3]");

  // ksizes with five entries instead of four is rejected.
  configure({1, 2, 2, 1, 1}, {1, 1, 1, 1}, {1, 2, 2, 1}, "VALID");
  INFER_ERROR(
      "ExtractImagePatches requires the ksizes attribute to contain 4 values, "
      "but got: 5",
      op, "[1,7,7,2]");
}
1355 
// Shape-function test for QuantizeAndDequantizeV2 with axis = -1: the output
// follows input 0, and rank mismatches on input_min / input_max are rejected.
TEST(ArrayOpsTest, QuantizeAndDequantizeV2_ShapeFn) {
  ShapeInferenceTestOp op("QuantizeAndDequantizeV2");
  op.input_tensors.resize(3);
  TF_ASSERT_OK(NodeDefBuilder("test", "QuantizeAndDequantizeV2")
                   .Input("input", 0, DT_FLOAT)
                   .Input("input_min", 1, DT_FLOAT)
                   .Input("input_max", 2, DT_FLOAT)
                   .Attr("signed_input", true)
                   .Attr("num_bits", 8)
                   .Attr("range_given", false)
                   .Attr("narrow_range", false)
                   .Attr("axis", -1)
                   .Finalize(&op.node_def));

  // The output shape is input 0's shape, whatever its rank.
  for (const char* shapes : {"?;?;?", "[];?;?", "[1,2,?,4,5];?;?"}) {
    INFER_OK(op, shapes, "in0");
  }

  // Rank-1 input_min and/or input_max.
  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1,2,?,4,5];[1];[]");
  INFER_ERROR("Shapes must be equal rank, but are 1 and 0", op,
              "[1,2,?,4,5];[];[1]");
  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1,2,?,4,5];[1];[1]");
}
1378 
// Shape-function test for SpaceToBatch with block_size = 2. The constant
// paddings are supplied through op.input_tensors[1].
TEST(ArrayOpsTest, SpaceToBatch_ShapeFn) {
  ShapeInferenceTestOp op("SpaceToBatch");
  op.input_tensors.resize(2);
  TF_ASSERT_OK(NodeDefBuilder("test", "SpaceToBatch")
                   .Input("input", 0, DT_FLOAT)
                   .Input("paddings", 1, DT_INT32)
                   .Attr("block_size", 2)
                   .Finalize(&op.node_def));

  // Padding values not known: the batch size can still be computed, but
  // width and height stay unknown. The same holds when even the paddings
  // shape is unknown.
  for (const char* shapes : {"[1,10,10,3];[2,2]", "[1,10,10,3];?"}) {
    INFER_OK(op, shapes, "[4,?,?,d0_3]");
  }

  // Paddings of the wrong shape.
  INFER_ERROR("rank", op, "[1,10,10,3];[4]");
  INFER_ERROR("3 and 2", op, "[1,10,10,3];[2,3]");

  // Known padding values, given as int32 and then as int64. The slot keeps
  // pointing at paddings_t, so later cases only reassign the tensor value.
  Tensor paddings_t = test::AsTensor<int32>({4, 2, 2, 4}, {{2, 2}});
  op.input_tensors[1] = &paddings_t;
  INFER_OK(op, "[1,10,10,3];[2,2]", "[4,8,8,d0_3]");
  paddings_t = test::AsTensor<int64>({4, 2, 2, 4}, {{2, 2}});
  INFER_OK(op, "[1,10,10,3];[2,2]", "[4,8,8,d0_3]");

  // Padded size (10 + 1 + 2 = 13) is not divisible by the block size.
  paddings_t = test::AsTensor<int32>({1, 2, 3, 4}, {{2, 2}});
  INFER_ERROR("Dimension size must be evenly divisible by 2 but is 13", op,
              "[1,10,10,3];[2,2]");

  // A negative padding value is rejected.
  paddings_t = test::AsTensor<int32>({1, -2, 3, 4}, {{2, 2}});
  INFER_ERROR("cannot be negative", op, "[1,10,10,3];[2,2]");
}
1415 
TEST(ArrayOpsTest,SpaceToBatchND_ShapeFn)1416 TEST(ArrayOpsTest, SpaceToBatchND_ShapeFn) {
1417   ShapeInferenceTestOp op("SpaceToBatchND");
1418   op.input_tensors.resize(3);
1419   TF_ASSERT_OK(NodeDefBuilder("test", "SpaceToBatchND")
1420                    .Input("input", 0, DT_FLOAT)
1421                    .Input("block_shape", 1, DT_INT32)
1422                    .Input("paddings", 2, DT_INT32)
1423                    .Finalize(&op.node_def));
1424 
1425   // Verify that input shape and paddings shape can be unknown.
1426   INFER_OK(op, "?;[2];?", "?");
1427 
1428   // Only number of input dimensions is known.
1429   INFER_OK(op, "[?,?,?,?];[2];?", "[?,?,?,d0_3]");
1430 
1431   // Dimensions are partially known.
1432   INFER_OK(op, "[?,?,?,2];[2];?", "[?,?,?,d0_3]");
1433 
1434   {
1435     // Dimensions are partially known, block_shape known.
1436     Tensor block_shape = test::AsTensor<int32>({2, 3});
1437     op.input_tensors[1] = &block_shape;
1438     INFER_OK(op, "[3,?,?,2];[2];?", "[18,?,?,d0_3]");
1439 
1440     // Dimensions are partially known, block_shape and paddings known.
1441     {
1442       Tensor paddings = test::AsTensor<int32>({1, 1, 0, 1}, {{2, 2}});
1443       op.input_tensors[2] = &paddings;
1444       INFER_OK(op, "[3,?,2,2];[2];[2,2]", "[18,?,1,d0_3]");
1445       op.input_tensors[2] = nullptr;
1446     }
1447 
1448     // Dimensions are fully known, block_shape and paddings are known.
1449     {
1450       Tensor paddings = test::AsTensor<int32>({1, 1, 0, 0}, {{2, 2}});
1451       op.input_tensors[2] = &paddings;
1452       INFER_OK(op, "[3,2,3,2];[2];[2,2]", "[18,2,1,d0_3]");
1453       op.input_tensors[2] = nullptr;
1454     }
1455 
1456     op.input_tensors[1] = nullptr;
1457   }
1458 
1459   INFER_ERROR("block_shape must have rank 1", op, "?;[1,1];?");
1460   INFER_ERROR("block_shape must have known size", op, "?;[?];?");
1461 
1462   {
1463     Tensor block_shape = test::AsTensor<int32>({0, 2});
1464     op.input_tensors[1] = &block_shape;
1465     INFER_ERROR("block_shape must be positive", op, "[1,2,2];[2];[2,2]");
1466     op.input_tensors[1] = nullptr;
1467   }
1468 
1469   {
1470     Tensor block_shape = test::AsTensor<int32>({1, 1});
1471     op.input_tensors[1] = &block_shape;
1472     Tensor paddings = test::AsTensor<int32>({0, -1, 0, 0}, {{2, 2}});
1473     op.input_tensors[2] = &paddings;
1474     INFER_ERROR("paddings cannot be negative", op, "[1,2,2];[2];[2,2]");
1475     op.input_tensors[1] = nullptr;
1476     op.input_tensors[2] = nullptr;
1477   }
1478 
1479   {
1480     Tensor block_shape = test::AsTensor<int32>({3, 3});
1481     op.input_tensors[1] = &block_shape;
1482     Tensor paddings = test::AsTensor<int32>({0, 0, 0, 0}, {{2, 2}});
1483     op.input_tensors[2] = &paddings;
1484     INFER_ERROR("divisible", op, "[1,2,3,1];[2];[2,2]");
1485     op.input_tensors[1] = nullptr;
1486     op.input_tensors[2] = nullptr;
1487   }
1488 
1489   {
1490     Tensor block_shape = test::AsTensor<int32>({});
1491     op.input_tensors[1] = &block_shape;
1492     Tensor paddings = test::AsTensor<int32>({});
1493     op.input_tensors[2] = &paddings;
1494     INFER_OK(op, "?;[0];[0,2]", "?");
1495     op.input_tensors[1] = nullptr;
1496     op.input_tensors[2] = nullptr;
1497   }
1498 
1499   INFER_ERROR("rank", op, "[1,3,3,1];[2];[1]");
1500   INFER_ERROR("shape", op, "[1,3,3,1];[2];[1,2]");
1501 }
1502 
// Shape-function test for BatchToSpace with block_size = 2. The constant
// croppings are supplied through op.input_tensors[1].
TEST(ArrayOpsTest, BatchToSpace_ShapeFn) {
  ShapeInferenceTestOp op("BatchToSpace");
  op.input_tensors.resize(2);
  TF_ASSERT_OK(NodeDefBuilder("test", "BatchToSpace")
                   .Input("input", 0, DT_FLOAT)
                   .Input("crops", 1, DT_INT32)
                   .Attr("block_size", 2)
                   .Finalize(&op.node_def));

  // Crop values not known: the batch size can still be computed.
  INFER_OK(op, "[4,8,8,3];[2,2]", "[1,?,?,d0_3]");

  // A batch size that block_size does not divide.
  INFER_ERROR("Dimension size must be evenly divisible by", op,
              "[5,8,8,3];[2,2]");

  // Unknown croppings means unknown width and height.
  INFER_OK(op, "[4,8,8,3];?", "[1,?,?,d0_3]");

  // Croppings of the wrong shape.
  INFER_ERROR("rank", op, "[4,8,8,3];[4]");
  INFER_ERROR("3 and 2", op, "[4,8,8,3];[2,3]");

  // Known crop values. The slot keeps pointing at crops_t, so later cases
  // only reassign the tensor value.
  Tensor crops_t = test::AsTensor<int64>({4, 2, 2, 4}, {{2, 2}});
  op.input_tensors[1] = &crops_t;
  INFER_OK(op, "[4,8,8,3];[2,2]", "[1,10,10,d0_3]");

  // Crops larger than the dimension they are taken from.
  const char* const kNegativeSize =
      "Negative dimension size caused by subtracting";
  crops_t = test::AsTensor<int32>({100, 2, 3, 4}, {{2, 2}});
  INFER_ERROR(kNegativeSize, op, "[4,8,8,3];[2,2]");
  crops_t = test::AsTensor<int32>({1, 2, 3, 400}, {{2, 2}});
  INFER_ERROR(kNegativeSize, op, "[4,8,8,3];[2,2]");

  // A negative cropping value is rejected.
  crops_t = test::AsTensor<int32>({1, -2, 3, 4}, {{2, 2}});
  INFER_ERROR("cannot be negative", op, "[4,8,8,3];[2,2]");
}
1545 
TEST(ArrayOpsTest,BatchToSpaceND_ShapeFn)1546 TEST(ArrayOpsTest, BatchToSpaceND_ShapeFn) {
1547   ShapeInferenceTestOp op("BatchToSpaceND");
1548   op.input_tensors.resize(3);
1549   TF_ASSERT_OK(NodeDefBuilder("test", "BatchToSpaceND")
1550                    .Input("input", 0, DT_FLOAT)
1551                    .Input("block_shape", 1, DT_INT32)
1552                    .Input("crops", 2, DT_INT32)
1553                    .Finalize(&op.node_def));
1554 
1555   // Verify that input shape and crops shape can be unknown.
1556   INFER_OK(op, "?;[2];?", "?");
1557 
1558   // Only number of input dimensions is known.
1559   INFER_OK(op, "[?,?,?,?];[2];?", "[?,?,?,d0_3]");
1560 
1561   {
1562     // Dimensions are partially known, block_shape known.
1563     Tensor block_shape = test::AsTensor<int32>({2, 3});
1564     op.input_tensors[1] = &block_shape;
1565     INFER_OK(op, "[?,?,?,2];[2];?", "[?,?,?,d0_3]");
1566 
1567     INFER_OK(op, "[18,?,?,2];[2];?", "[3,?,?,d0_3]");
1568 
1569     // Dimensions are partially known, block_shape and crops known.
1570     {
1571       Tensor crops = test::AsTensor<int32>({1, 1, 0, 1}, {{2, 2}});
1572       op.input_tensors[2] = &crops;
1573       INFER_OK(op, "[18,?,2,2];[2];[2,2]", "[3,?,5,d0_3]");
1574       op.input_tensors[2] = nullptr;
1575     }
1576 
1577     // Dimensions are fully known, block_shape and crops are known.
1578     {
1579       Tensor crops = test::AsTensor<int32>({1, 1, 0, 0}, {{2, 2}});
1580       op.input_tensors[2] = &crops;
1581       INFER_OK(op, "[18,2,1,2];[2];[2,2]", "[3,2,3,d0_3]");
1582       op.input_tensors[2] = nullptr;
1583     }
1584 
1585     op.input_tensors[1] = nullptr;
1586   }
1587 
1588   INFER_ERROR("block_shape must have rank 1", op, "?;[1,1];?");
1589   INFER_ERROR("block_shape must have known size", op, "?;[?];?");
1590   INFER_ERROR("rank", op, "[2,2];[2];[2,2]");
1591   INFER_ERROR("rank", op, "[2,2,3];[3];[3,2]");
1592 
1593   {
1594     Tensor block_shape = test::AsTensor<int32>({0, 2});
1595     op.input_tensors[1] = &block_shape;
1596     INFER_ERROR("block_shape must be positive", op, "[1,2,2];[2];[2,2]");
1597     op.input_tensors[1] = nullptr;
1598   }
1599 
1600   {
1601     Tensor block_shape = test::AsTensor<int32>({1, 1});
1602     op.input_tensors[1] = &block_shape;
1603     Tensor paddings = test::AsTensor<int32>({0, -1, 0, 0}, {{2, 2}});
1604     op.input_tensors[2] = &paddings;
1605     INFER_ERROR("crops cannot be negative", op, "[1,2,2];[2];[2,2]");
1606     op.input_tensors[1] = nullptr;
1607     op.input_tensors[2] = nullptr;
1608   }
1609 
1610   // The amount to crop exceeds the padded size.
1611   {
1612     Tensor block_shape = test::AsTensor<int32>({2, 2});
1613     op.input_tensors[1] = &block_shape;
1614     Tensor crops = test::AsTensor<int32>({3, 2, 0, 0}, {{2, 2}});
1615     op.input_tensors[2] = &crops;
1616     INFER_ERROR("Negative", op, "[4,2,3,1];[2];[2,2]");
1617     op.input_tensors[1] = nullptr;
1618     op.input_tensors[2] = nullptr;
1619   }
1620 
1621   // The batch size is not divisible by the product of the block_shape.
1622   {
1623     Tensor block_shape = test::AsTensor<int32>({2, 3});
1624     op.input_tensors[1] = &block_shape;
1625     INFER_ERROR("divisible", op, "[3,1,1,1];[2];[2,2]");
1626     op.input_tensors[1] = nullptr;
1627   }
1628 }
1629 
// Shape-inference checks for SpaceToDepth with block_size = 2.
TEST(ArrayOpsTest, SpaceToDepth_ShapeFn) {
  ShapeInferenceTestOp op("SpaceToDepth");

  NodeDefBuilder builder("test", "SpaceToDepth");
  builder.Input("input", 0, DT_FLOAT).Attr("block_size", 2);
  TF_ASSERT_OK(builder.Finalize(&op.node_def));

  // Fully known input.
  INFER_OK(op, "[1,2,4,4]", "[d0_0,1,2,16]");

  // A spatial dimension that block_size does not divide is an error.
  INFER_ERROR("Dimension size must be evenly divisible by 2 but is 3", op,
              "[1,3,8,4]");
  INFER_ERROR("Dimension size must be evenly divisible by 2 but is 5", op,
              "[1,2,5,4]");

  // An unknown input depth yields an unknown output depth.
  INFER_OK(op, "[1,2,4,?]", "[d0_0,1,2,?]");
}
1648 
// Shape-inference checks for DepthToSpace, first with block_size = 2 and then
// with block_size = 10.
TEST(ArrayOpsTest, DepthToSpace_ShapeFn) {
  ShapeInferenceTestOp op("DepthToSpace");

  NodeDefBuilder small_block("test", "DepthToSpace");
  small_block.Input("input", 0, DT_FLOAT).Attr("block_size", 2);
  TF_ASSERT_OK(small_block.Finalize(&op.node_def));

  // Fully known input.
  INFER_OK(op, "[1,1,2,16]", "[d0_0,2,4,4]");

  // A depth of 15 is rejected: it is not divisible by 4.
  INFER_ERROR("Dimension size must be evenly divisible by 4 but is 15", op,
              "[1,1,2,15]");

  // An unknown input depth yields an unknown output depth.
  INFER_OK(op, "[1,2,4,?]", "[d0_0,4,8,?]");

  // Repeat with a larger block size.
  NodeDefBuilder large_block("test", "DepthToSpace");
  large_block.Input("input", 0, DT_FLOAT).Attr("block_size", 10);
  TF_ASSERT_OK(large_block.Finalize(&op.node_def));
  INFER_OK(op, "[1,1,2,200]", "[d0_0,10,20,2]");
}
1672 
// Shape-inference checks for Slice: shape-only inputs first, then constant
// begin / sizes values.
TEST(ArrayOpsTest, Slice_ShapeFn) {
  ShapeInferenceTestOp op("Slice");

  NodeDefBuilder builder("test", "Slice");
  builder.Input("input", 0, DT_FLOAT)
      .Input("begin", 1, DT_INT64)
      .Input("sizes", 2, DT_INT64);
  TF_ASSERT_OK(builder.Finalize(&op.node_def));

  // Input rank and begin/sizes lengths are known but their values are not:
  // only the output rank can be inferred.
  INFER_OK(op, "[2,3,4,5];[4];[4]", "[?,?,?,?]");

  // begin/sizes of unknown length; the input rank still fixes the output rank.
  INFER_OK(op, "[2,3,4,5];[?];[?]", "[?,?,?,?]");
  // Nothing is known at all.
  INFER_OK(op, "?;[?];[?]", "?");
  // The length of begin alone determines the output rank.
  INFER_OK(op, "?;[4];[?]", "[?,?,?,?]");

  // begin and sizes must each be vectors.
  INFER_ERROR("must be rank 1", op, "[2,3,4,5];[2,3];[3]");
  INFER_ERROR("must be rank 1", op, "[2,3,4,5];[2];[3,4]");
  // The length of begin has to agree with the input rank.
  INFER_ERROR("must be rank 2", op, "[2,3,4,5];[2];[2]");

  // From here on begin and sizes are supplied as constant tensors. The two
  // Tensor objects are reassigned in place, so the registered pointers stay
  // valid for the remaining checks.
  op.input_tensors.resize(3);
  Tensor begin_t = test::AsTensor<int32>({0, 1, 2, 1});
  Tensor sizes_t = test::AsTensor<int32>({1, 2, 1, 3});
  op.input_tensors[1] = &begin_t;
  op.input_tensors[2] = &sizes_t;
  INFER_OK(op, "[2,3,4,5];[4];[4]", "[1,2,1,3]");

  // A size of -1 takes everything from begin to the end of that dimension.
  sizes_t = test::AsTensor<int32>({-1, -1, 1, -1});
  INFER_OK(op, "[2,3,4,5];[4];[4]", "[d0_0,2,1,4]");

  // begin = 6 in a dimension of size 5, combined with size -1, is an error.
  begin_t = test::AsTensor<int32>({0, 1, 2, 6});
  sizes_t = test::AsTensor<int32>({-1, -1, -1, -1});
  INFER_ERROR("Negative dimension size", op, "[2,3,4,5];[4];[4]");

  // Sizes below -1 are rejected.
  begin_t = test::AsTensor<int32>({0, 1, 2, 5});
  sizes_t = test::AsTensor<int32>({-1, -1, -1, -2});
  INFER_ERROR("cannot be < -1", op, "[2,3,4,5];[4];[4]");
}
1718 
// Shape-inference checks for StridedSlice with shrink_axis_mask = 1 and a
// constant stride of 1.
TEST(ArrayOpsTest, StridedSlice_ShapeFn) {
  ShapeInferenceTestOp op("StridedSlice");

  NodeDefBuilder builder("test", "StridedSlice");
  builder.Input("input", 0, DT_FLOAT)
      .Input("begin", 1, DT_INT32)
      .Input("end", 2, DT_INT32)
      .Input("strides", 3, DT_INT32)
      .Attr("shrink_axis_mask", 1);
  TF_ASSERT_OK(builder.Finalize(&op.node_def));

  // Only strides is given a constant value; begin and end are left unset.
  op.input_tensors.resize(4);
  Tensor unit_stride = test::AsTensor<int32>({1});
  op.input_tensors[3] = &unit_stride;

  // Dimension 0 is sliced away, leaving the remaining three.
  INFER_OK(op, "[2,3,4,5];[1];[1];[1]", "[3,4,5]");
  // Same slice, but one of the surviving dimensions has size 0.
  INFER_OK(op, "[2,0,3,4];[1];[1];[1]", "[0,3,4]");
}
1736 
// Shape-inference checks for StridedSliceGrad: the output shape follows the
// first input, by its length when only the shape is known and by its values
// once a constant tensor is supplied.
TEST(ArrayOpsTest, StridedSliceGrad_ShapeFn) {
  ShapeInferenceTestOp op("StridedSliceGrad");
  op.input_tensors.resize(5);

  // No constant value for input 0: at best the output rank is known.
  INFER_OK(op, "?;?;?;?;?", "?");
  INFER_OK(op, "[?];?;?;?;?", "?");
  INFER_OK(op, "[4];?;?;?;?", "[?,?,?,?]");

  // Constant value for input 0: the output shape is fully determined.
  Tensor shape_values = test::AsTensor<int32>({1, 2, 3, 4});
  op.input_tensors[0] = &shape_values;
  INFER_OK(op, "[4];?;?;?;?", "[1,2,3,4]");
}
1748 
// Runs the same checks against Dequantize and FakeQuantWithMinMaxVars: the
// output shape equals the first input's shape and the two trailing inputs must
// be scalars.
TEST(ArrayOpsTest, UnchangedWithQuantizationScalars_ShapeFn) {
  static const char* const kOpNames[] = {"Dequantize",
                                         "FakeQuantWithMinMaxVars"};
  for (const char* op_name : kOpNames) {
    ShapeInferenceTestOp op(op_name);

    // Only the entry starting with 'D' (Dequantize) gets an explicit NodeDef.
    const bool is_dequantize = (op_name[0] == 'D');
    if (is_dequantize) {
      NodeDefBuilder builder("test", "Dequantize");
      builder.Input("input", 0, DT_QINT8)
          .Input("input_min", 1, DT_FLOAT)
          .Input("input_max", 2, DT_FLOAT)
          .Attr("T", DataTypeToEnum<qint8>::v())
          .Attr("mode", "SCALED")
          .Attr("axis", -1);
      TF_ASSERT_OK(builder.Finalize(&op.node_def));
    }

    // The output mirrors input 0.
    INFER_OK(op, "?;?;?", "in0");
    INFER_OK(op, "[1,?,3];[];[]", "in0");

    // Inputs 1 and 2 must each be rank 0.
    INFER_ERROR("be rank 0", op, "[1,?,3];[1];[]");
    INFER_ERROR("be rank 0", op, "[1,?,3];[];[1]");
  }
}
1770 
// Shape-inference checks for FakeQuantWithMinMaxVarsPerChannel, expressed as
// tables of input-shape strings grouped by expected outcome.
TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannel) {
  ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannel");

  // Accepted inputs; the output is always the shape of input 0.
  static const char* const kAccepted[] = {
      "?;?;?",
      "[?];?;?",
      "[1,?,3];[3];[3]",
      "[3];[3];[3]",
  };
  for (const char* shapes : kAccepted) {
    INFER_OK(op, shapes, "in0");
  }

  // Inputs 1 and 2 must both be vectors.
  static const char* const kNotVectors[] = {
      "[1,?,3];[1];[]",
      "[1,?,3];[];[1]",
  };
  for (const char* shapes : kNotVectors) {
    INFER_ERROR("be rank 1", op, shapes);
  }

  // The two vectors must agree with each other and with the last dimension of
  // input 0.
  static const char* const kLengthMismatch[] = {
      "[1,?,3];[2];[?]",
      "[1,?,3];[?];[2]",
      "[1,?,?];[1];[2]",
      "[5];[4];[?]",
  };
  for (const char* shapes : kLengthMismatch) {
    INFER_ERROR("must be equal", op, shapes);
  }
}
1789 
// Shape-inference checks for FakeQuantWithMinMaxVarsPerChannelGradient, which
// yields three outputs: one shaped like input 0 and two vectors.
TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannelGradient) {
  ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannelGradient");

  // Accepted inputs of rank 1, 2 and 4, plus the fully unknown case.
  INFER_OK(op, "?;?;?;?", "in0;[?];[?]");
  INFER_OK(op, "[3];[3];[3];[3]", "in0;in3;in3");
  INFER_OK(op, "[1,3];[1,3];[3];[3]", "in0;in3;in3");
  INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4]", "in0;in3;in3");

  // Rank constraints on the min/max vectors (inputs 2 and 3).
  INFER_ERROR("be equal rank", op, "[1,?,3];[1,?,3];[3];[]");
  INFER_ERROR("be rank 1", op, "[1,?,3];[1,?,3];[];[3]");

  // Rank constraints on the first two inputs: at least 1, at most 4.
  INFER_ERROR("be at least rank 1", op, "[];[];[1];[1]");
  INFER_ERROR("be at most rank 4", op, "[1,2,3,4,5];[1,2,3,4,5];[1];[1]");

  // The vectors must agree with each other and with the last dimension of
  // input 0.
  INFER_ERROR("must be equal", op, "[1,3];[1,3];[2];[3]");
  INFER_ERROR("must be equal", op, "[1,3];[1,3];[3];[2]");
}
1808 
// Shape-inference checks for QuantizedConcat. Its inputs are laid out as
// concat_dim, N value tensors, then two lists of N float tensors each.
TEST(ArrayOpsTest, QuantizedConcat_ShapeFn) {
  ShapeInferenceTestOp op("QuantizedConcat");

  // Rebuilds op.node_def for a concat over `count` quantized inputs.
  auto rebuild_node_def = [&op](int count) {
    const std::vector<NodeDefBuilder::NodeOut> values(
        count, NodeDefBuilder::NodeOut("a", 0, DT_QUINT8));
    const std::vector<NodeDefBuilder::NodeOut> limits(
        count, NodeDefBuilder::NodeOut("b", 0, DT_FLOAT));
    TF_ASSERT_OK(NodeDefBuilder("test", "QuantizedConcat")
                     .Input({"concat_dim", 0, DT_INT32})
                     .Input(values)
                     .Input(limits)
                     .Input(limits)
                     .Attr("N", count)
                     .Finalize(&op.node_def));
  };

  // Input 0 (concat_dim) has to be a scalar.
  rebuild_node_def(1);
  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1];?;?;?");

  // Each of the trailing 2*N inputs has to be a scalar.
  rebuild_node_def(2);
  INFER_ERROR("must be rank 0", op, "[];?;?;?;?;?;[1]");
  INFER_ERROR("must be rank 0", op, "[];?;?;?;?;[1];?");
  INFER_ERROR("must be rank 0", op, "[];?;?;?;[1];?;?");
  INFER_ERROR("must be rank 0", op, "[];?;?;[1];?;?;?");

  // The N value inputs that follow concat_dim must be concat-compatible.
  rebuild_node_def(2);
  INFER_ERROR("must be rank 2", op, "[];[1,2];[1,2,3];?;?;?;?");
  INFER_OK(op, "[];[1,2];[1,3];?;?;?;?", "[?,?];[];[]");

  // With a constant concat_dim of 0, dimension 0 is summed across the value
  // inputs and the remaining dimensions are merged.
  Tensor axis_t = test::AsScalar(0);
  op.input_tensors.push_back(&axis_t);
  rebuild_node_def(2);
  INFER_OK(op, "[];[100,2,?];[10,?,3];?;?;?;?", "[110,d1_1,d2_2];[];[]");
  INFER_ERROR("Dimension 1 in both shapes must be equal, but are 5 and 3", op,
              "[];[100,2,5];[10,?,3];?;?;?;?");
  // The remaining concat behaviour is exercised by the Concat tests.
}
1854 
// _ParallelConcatStart takes no inputs; its output shape is read from the
// "shape" attribute.
TEST(StateOpsTest, _ParallelConcatStart_ShapeFn) {
  ShapeInferenceTestOp op("_ParallelConcatStart");

  // Serialise the shape [1,2,3] into the proto form the attribute expects.
  TensorShapeProto shape_attr;
  TensorShape({1, 2, 3}).AsProto(&shape_attr);

  NodeDefBuilder builder("test", "_ParallelConcatStart");
  builder.Attr("shape", shape_attr).Attr("dtype", DT_FLOAT);
  TF_ASSERT_OK(builder.Finalize(&op.node_def));

  INFER_OK(op, "", "[1,2,3]");
}
1866 
1867 }  // end namespace tensorflow
1868