• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/fake_input.h"
17 #include "tensorflow/core/framework/node_def_builder.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/shape_inference_testutil.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 #include "tensorflow/core/platform/test.h"
23 
24 namespace tensorflow {
25 
// Shape inference for TopK, where k is an attribute: both outputs have the
// input's shape with its last dimension replaced by k.
TEST(NNOpsTest, TopK_ShapeFn) {
  ShapeInferenceTestOp op("TopK");

  // Rebuilds op.node_def with the given value for the "k" attribute. The node
  // is built as a TopK node with a single float input so that the NodeDef
  // describes the op under test rather than an unrelated one.
  auto set_k = [&op](int k) {
    TF_ASSERT_OK(NodeDefBuilder("test", "TopK")
                     .Input(FakeInput(DT_FLOAT))
                     .Attr("k", k)
                     .Finalize(&op.node_def));
  };

  set_k(20);
  // With unknown input, each output is an unknown shape.
  INFER_OK(op, "?", "?;?");
  // With vector input, each output is [k].
  INFER_OK(op, "[20]", "[20];[20]");
  INFER_OK(op, "[21]", "[20];[20]");

  // With input rank 3, each output is the first 2 dims of input, plus k.
  INFER_OK(op, "[1,?,21]", "[d0_0,d0_1,20];[d0_0,d0_1,20]");
  // With input rank 4, each output is the first 3 dims of input, plus k.
  INFER_OK(op, "[1,?,21,?]", "[d0_0,d0_1,d0_2,20];[d0_0,d0_1,d0_2,20]");

  // The input must be at least a vector, and its last dimension (when known)
  // must be at least k.
  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]");
  INFER_ERROR("input must have last dimension >= k = 20 but is 1", op, "[1]");
  INFER_ERROR("input must have last dimension >= k = 20 but is 4", op,
              "[1,2,3,4]");

  // A negative k attribute is rejected by the shape function.
  set_k(-1);
  INFER_ERROR("Need k >= 0, got -1", op, "[1,2,3,4]");
}
54 
// Shape inference for TopKV2, where k is read from the scalar input 1: both
// outputs have input 0's shape with its last dimension replaced by k.
TEST(NNOpsTest, TopKV2_ShapeFn) {
  ShapeInferenceTestOp op("TopKV2");
  op.input_tensors.resize(2);

  // k_t supplies the constant value of input 1 (k) to the shape function.
  Tensor k_t;
  op.input_tensors[1] = &k_t;

  k_t = test::AsScalar<int32>(20);
  // With unknown input, each output is an unknown shape.
  INFER_OK(op, "?;[]", "?;?");
  // With vector input, each output is [k].
  INFER_OK(op, "[20];[]", "[20];[20]");
  INFER_OK(op, "[21];[]", "[20];[20]");

  // With input rank 3, each output is the first 2 dims of input, plus k.
  INFER_OK(op, "[1,?,21];[]", "[d0_0,d0_1,20];[d0_0,d0_1,20]");
  // With input rank 4, each output is the first 3 dims of input, plus k.
  INFER_OK(op, "[1,?,21,?];[]", "[d0_0,d0_1,d0_2,20];[d0_0,d0_1,d0_2,20]");

  // Input 0 must be at least a vector whose last dimension is at least k.
  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[]");
  INFER_ERROR("input must have last dimension >= k = 20 but is 1", op,
              "[1];[]");
  INFER_ERROR("input must have last dimension >= k = 20 but is 4", op,
              "[1,2,3,4];[]");

  // A negative k read from the scalar input is rejected.
  k_t = test::AsScalar<int32>(-1);
  INFER_ERROR(
      "Dimension size, given by scalar input 1, must be non-negative but is -1",
      op, "[1,2,3,4];[]");
}
83 
// NthElement drops the last dimension of input 0. The scalar input 1 (n) must
// be non-negative, and the last dimension of input 0 must be greater than n.
TEST(NNOpsTest, NthElement_ShapeFn) {
  ShapeInferenceTestOp op("NthElement");
  op.input_tensors.resize(2);

  // Feed a constant value for n (input 1) so the shape function can read it.
  Tensor n_value = test::AsScalar<int32>(20);
  op.input_tensors[1] = &n_value;

  // The output is input 0 without its last dimension.
  INFER_OK(op, "?;[]", "?");
  INFER_OK(op, "[21];[]", "[]");
  INFER_OK(op, "[2,?,?];[]", "[d0_0,d0_1]");
  INFER_OK(op, "[?,3,?,21];[]", "[d0_0,d0_1,d0_2]");

  // Input 0 needs rank >= 1 and a last dimension strictly greater than n.
  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[]");
  INFER_ERROR("Input must have last dimension > n = 20 but is 1", op, "[1];[]");
  INFER_ERROR("Input must have last dimension > n = 20 but is 20", op,
              "[1,2,3,20];[]");

  // n itself has to be non-negative.
  n_value = test::AsScalar<int32>(-1);
  INFER_ERROR(
      "Dimension size, given by scalar input 1, must be non-negative but is -1",
      op, "[1,2,3,4];[]");
}
106 
// BatchNormWithGlobalNormalization takes a rank-4 input plus four vectors. The
// single output is rank 4. Its last dimension is the merge of input 0's last
// dimension with the single dimension of every vector input.
TEST(NNOpsTest, BatchNormWithGlobalNormalization_ShapeFn) {
  ShapeInferenceTestOp op("BatchNormWithGlobalNormalization");

  // Input 0 has to be rank 4...
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
  // ...and inputs 1-4 have to be rank 1.
  for (const char* bad_inputs : {"?;[1,2,3];?;?;?", "?;?;[1,2,3];?;?",
                                 "?;?;?;[1,2,3];?", "?;?;?;?;[1,2,3]"}) {
    INFER_ERROR("Shape must be rank 1 but is rank 3", op, bad_inputs);
  }

  // Nothing known: an unknown rank-4 shape.
  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?]");
  // Any single known vector input fixes the last output dimension.
  INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0]");
  INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0]");
  INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0]");
  INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0]");
  // Everything known: the last dimension merges all five candidates.
  INFER_OK(op, "[1,2,3,4];[4];[4];[4];[4]",
           "[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0|d3_0|d4_0]");
}
126 
// Mirrors the BatchNormWithGlobalNormalization checks. Here every tensor input
// is followed by two scalar (min/max) inputs, and the output is followed by
// two scalar outputs. The tensor inputs therefore sit at indices 0, 3, 6, 9
// and 12.
TEST(NNOpsTest, QuantizedBatchNormWithGlobalNormalization_ShapeFn) {
  ShapeInferenceTestOp op("QuantizedBatchNormWithGlobalNormalization");

  // Input 0 must be rank 4; the vector inputs must be rank 1.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op,
              "[1,2,3];?;?;?;?;?;?;?;?;?;?;?;?;?;?");
  for (const char* bad_inputs : {"?;?;?;[1,2,3];?;?;?;?;?;?;?;?;?;?;?",
                                 "?;?;?;?;?;?;[1,2,3];?;?;?;?;?;?;?;?",
                                 "?;?;?;?;?;?;?;?;?;[1,2,3];?;?;?;?;?",
                                 "?;?;?;?;?;?;?;?;?;?;?;?;[1,2,3];?;?"}) {
    INFER_ERROR("Shape must be rank 1 but is rank 3", op, bad_inputs);
  }

  // Nothing known about the tensor inputs: an unknown rank-4 output plus the
  // two scalar outputs.
  INFER_OK(op, "?;[];[];?;[];[];?;[];[];?;[];[];?;[];[]", "[?,?,?,?];[];[]");

  // A single known vector input fixes the last output dimension.
  INFER_OK(op, "?;[];[];[1];[];[];?;[];[];?;[];[];?;[];[]",
           "[?,?,?,d3_0];[];[]");
  INFER_OK(op, "?;[];[];?;[];[];[1];[];[];?;[];[];?;[];[]",
           "[?,?,?,d6_0];[];[]");
  INFER_OK(op, "?;[];[];?;[];[];?;[];[];[1];[];[];?;[];[]",
           "[?,?,?,d9_0];[];[]");
  INFER_OK(op, "?;[];[];?;[];[];?;[];[];?;[];[];[1];[];[]",
           "[?,?,?,d12_0];[];[]");

  // Everything known: the last dimension merges all five candidates.
  INFER_OK(op, "[1,2,3,4];[];[];[4];[];[];[4];[];[];[4];[];[];[4];[];[]",
           "[d0_0,d0_1,d0_2,d0_3|d3_0|d6_0|d9_0|d12_0];[];[]");
}
158 
// BatchNormWithGlobalNormalizationGrad: output 0 is the merge of inputs 0 and
// 4 (both rank 4). The last dimension of that merge is further merged with the
// vector inputs 1-3, and the resulting vector is the shape of outputs 1-4.
TEST(NNOpsTest, BatchNormWithGlobalNormalizationGrad_ShapeFn) {
  ShapeInferenceTestOp op("BatchNormWithGlobalNormalizationGrad");

  // Rank requirements: input 0 is rank 4, inputs 1-3 are rank 1.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
  for (const char* bad_inputs :
       {"?;[1,2,3];?;?;?", "?;?;[1,2,3];?;?", "?;?;?;[1,2,3];?"}) {
    INFER_ERROR("Shape must be rank 1 but is rank 3", op, bad_inputs);
  }
  // Input 4 is merged with input 0, so its rank mismatch is reported as an
  // equal-rank failure.
  INFER_ERROR("Shapes must be equal rank, but are 4 and 3", op,
              "?;?;?;?;[1,2,3]");

  // Nothing known.
  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
  // A known vector input determines the merged last dimension everywhere.
  INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0];[d1_0];[d1_0];[d1_0];[d1_0]");
  INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
  INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[d3_0];[d3_0]");
  // Known dims of inputs 0 and 4 are combined in output 0.
  INFER_OK(op, "[1,?,3,?];[?];[?];[?];[?,2,?,4]",
           "[d0_0,d4_1,d0_2,d4_3];[d4_3];[d4_3];[d4_3];[d4_3]");
}
180 
TEST(NNOpsTest,FusedBatchNorm_ShapeFn)181 TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
182   ShapeInferenceTestOp op("FusedBatchNorm");
183 
184   auto set_op = [&op](bool is_training, float exponential_avg_factor,
185                       string data_format) {
186     TF_ASSERT_OK(NodeDefBuilder("test", "FusedBatchNorm")
187                      .Input(FakeInput(DT_FLOAT))
188                      .Input(FakeInput(DT_FLOAT))
189                      .Input(FakeInput(DT_FLOAT))
190                      .Input(FakeInput(DT_FLOAT))
191                      .Input(FakeInput(DT_FLOAT))
192                      .Attr("data_format", data_format)
193                      .Attr("is_training", is_training)
194                      .Attr("exponential_avg_factor", exponential_avg_factor)
195                      .Finalize(&op.node_def));
196   };
197 
198   set_op(true, 1.0, "NHWC");
199   // Test rank errors.
200   INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
201   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
202   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
203   // Channel dim of first input is merged with the single dim in other 4 inputs.
204   INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
205   INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0];[d1_0];[d1_0];[d1_0];[d1_0]");
206   INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
207   INFER_OK(op, "[1,2,3,4];[4];[4];?;?",
208            "[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0];"
209            "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];"
210            "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]");
211 
212   set_op(true, 0.5, "NHWC");
213   // Test rank errors.
214   INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
215   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
216   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
217   // Channel dim of first input is merged with the single dim in other 4 inputs.
218   INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
219   INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0];[d1_0];[d1_0];[d1_0];[d1_0]");
220   INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
221   INFER_OK(op, "[1,2,3,4];[4];[4];?;?",
222            "[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0];"
223            "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];"
224            "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]");
225 
226   set_op(true, 1.0, "NCHW");
227   // Test rank errors.
228   INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
229   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
230   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
231   // Channel dim of first input is merged with the single dim in other 4 inputs.
232   INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
233   INFER_OK(op, "?;[1];?;?;?", "[?,d1_0,?,?];[d1_0];[d1_0];[d1_0];[d1_0]");
234   INFER_OK(op, "?;?;[1];?;?", "[?,d2_0,?,?];[d2_0];[d2_0];[d2_0];[d2_0]");
235   INFER_OK(op, "[1,4,2,3];[4];[4];?;?",
236            "[d0_0,d0_1|d1_0|d2_0,d0_2,d0_3];"
237            "[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0];"
238            "[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0]");
239 
240   set_op(false, 1.0, "NHWC");
241   // Test rank errors.
242   INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
243   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
244   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
245   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
246   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
247   // Channel dim of first input is merged with the single dim in other 4 inputs.
248   INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
249   INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0];[d1_0];[d1_0];[d1_0];[d1_0]");
250   INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
251   INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[d3_0];[d3_0]");
252   INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[d4_0];[d4_0]");
253   INFER_OK(op, "[1,2,3,4];[4];[4];[4];[4]",
254            "[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0|d3_0|d4_0];"
255            "[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0];"
256            "[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0]");
257 
258   set_op(false, 1.0, "NCHW");
259   // Test rank errors.
260   INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
261   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
262   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
263   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
264   INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
265   // Channel dim of first input is merged with the single dim in other 4 inputs.
266   INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
267   INFER_OK(op, "?;[1];?;?;?", "[?,d1_0,?,?];[d1_0];[d1_0];[d1_0];[d1_0]");
268   INFER_OK(op, "?;?;[1];?;?", "[?,d2_0,?,?];[d2_0];[d2_0];[d2_0];[d2_0]");
269   INFER_OK(op, "?;?;?;[1];?", "[?,d3_0,?,?];[d3_0];[d3_0];[d3_0];[d3_0]");
270   INFER_OK(op, "?;?;?;?;[1]", "[?,d4_0,?,?];[d4_0];[d4_0];[d4_0];[d4_0]");
271   INFER_OK(op, "[1,4,2,3];[4];[4];[4];[4]",
272            "[d0_0,d0_1|d1_0|d2_0|d3_0|d4_0,d0_2,d0_3];"
273            "[d0_1|d1_0|d2_0|d3_0|d4_0];[d0_1|d1_0|d2_0|d3_0|d4_0];"
274            "[d0_1|d1_0|d2_0|d3_0|d4_0];[d0_1|d1_0|d2_0|d3_0|d4_0]");
275 }
276 
// FusedBatchNormGrad shape inference. Inputs 0 and 1 must be rank 4 and inputs
// 2-4 must be vectors. Output 0 is rank 4, and its channel dim (dim 1 for
// NCHW, last for NHWC) is merged with the single dim of inputs 2-4. Outputs 1
// and 2 are vectors of that channel dim, and outputs 3 and 4 are always [0].
TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
  ShapeInferenceTestOp op("FusedBatchNormGrad");

  // Rebuilds op.node_def with five float inputs and the given data format.
  auto set_op = [&op](const string& data_format) {
    TF_ASSERT_OK(NodeDefBuilder("test", "FusedBatchNormGrad")
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_FLOAT))
                     .Attr("data_format", data_format)
                     .Finalize(&op.node_def));
  };

  set_op("NCHW");
  // Test rank errors: inputs 0 and 1 are rank 4, inputs 2-4 are rank 1.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "?;[1,2,3];?;?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
  // The channel dim (dim 1 of input 0) is merged with the single dim of the
  // three vector inputs 2-4; the last two outputs are always [0].
  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
  INFER_OK(op, "?;?;[1];?;?", "[?,d2_0,?,?];[d2_0];[d2_0];[0];[0]");
  INFER_OK(op, "?;?;?;[1];?", "[?,d3_0,?,?];[d3_0];[d3_0];[0];[0]");
  INFER_OK(op, "?;?;?;?;[1]", "[?,d4_0,?,?];[d4_0];[d4_0];[0];[0]");
  INFER_OK(op, "[1,4,2,3];[1,4,2,3];[4];[4];[4]",
           "[d0_0,d0_1|d2_0|d3_0|d4_0,d0_2,d0_3];"
           "[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];[0];[0]");

  set_op("NHWC");
  // Test rank errors: inputs 0 and 1 are rank 4, inputs 2-4 are rank 1.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "?;[1,2,3];?;?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
  // The channel dim (last dim of input 0) is merged with the single dim of the
  // three vector inputs 2-4; the last two outputs are always [0].
  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
  INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[0];[0]");
  INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[0];[0]");
  INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[0];[0]");
  INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4];[4]",
           "[d0_0,d0_1,d0_2,d0_3|d2_0|d3_0|d4_0];"
           "[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[0];[0]");
}
322 
// Conv2DBackpropInput: input 0 (input_sizes) must hold 4 or 2 values. Input 1
// is the rank-4 filter, and input 2 must also be rank 4.
TEST(NNOpsTest, Conv2DBackpropInput_ShapeFn) {
  ShapeInferenceTestOp op("Conv2DBackpropInput");

  // 4-value input_sizes: the output's batch dim comes from input 2. The
  // channel dim would come from the contents of input_sizes, which are not
  // provided here, so it stays unknown. This is needed because, for grouped
  // convolutions, it need not equal the filter's input-channel dim.
  INFER_OK(op, "[4];[?,?,2,?];[1,?,?,?]", "[d2_0,?,?,?]");
  // 2-value input_sizes: the channel dim always comes from the filter
  // (dim 2 of input 1).
  INFER_OK(op, "[2];[?,?,2,?];[1,?,?,?]", "[d2_0,?,?,d1_2]");

  // Malformed inputs are rejected.
  INFER_ERROR("input_sizes to contain 4 values or 2 values", op,
              "[3];[?,?,?,?];[?,?,?,?]");
  INFER_ERROR("Shape must be rank 4 but is rank 3", op,
              "[4];[?,?,?,?];[?,?,?]");
}
341 
// Conv3DBackpropInput shape inference: input 0 must be rank 5 and is passed
// through unchanged as the output shape.
TEST(NNOpsTest, Conv3DBackpropInput_ShapeFn) {
  ShapeInferenceTestOp op("Conv3DBackpropInput");

  // Test rank error.
  INFER_ERROR("Shape must be rank 5 but is rank 3", op, "[1,2,3];?;?");

  // input[0] (not input[1]) is transferred to output after asserting its rank.
  INFER_OK(op, "?;?;?", "[?,?,?,?,?]");
  INFER_OK(op, "[?,?,?,?,?];?;?", "in0");
  INFER_OK(op, "[?,2,?,4,?];?;?", "in0");
}
353 
// Conv3DBackpropFilter: input 1 must be rank 5 and becomes the output shape.
TEST(NNOpsTest, Conv3DBackpropFilter_ShapeFn) {
  ShapeInferenceTestOp op("Conv3DBackpropFilter");

  // With input 1 unknown, the output is an unknown rank-5 shape.
  INFER_OK(op, "?;?;?", "[?,?,?,?,?]");
  // Otherwise input 1's shape is forwarded unchanged.
  for (const char* ins : {"?;[?,?,?,?,?];?", "?;[?,2,?,4,?];?"}) {
    INFER_OK(op, ins, "in1");
  }

  // A non-rank-5 input 1 is rejected.
  INFER_ERROR("Shape must be rank 5 but is rank 3", op, "?;[1,2,3];?");
}
365 
// MaxPool3DGrad: input 0 must be rank 5 and becomes the output shape.
TEST(NNOpsTest, MaxPool3DGrad_ShapeFn) {
  ShapeInferenceTestOp op("MaxPool3DGrad");

  // Unknown input 0 yields an unknown rank-5 output...
  INFER_OK(op, "?;?;?", "[?,?,?,?,?]");
  // ...while a known rank-5 input 0 is forwarded as-is.
  for (const char* ins : {"[?,?,?,?,?];?;?", "[?,2,?,4,?];?;?"}) {
    INFER_OK(op, ins, "in0");
  }

  // A non-rank-5 input 0 is rejected.
  INFER_ERROR("Shape must be rank 5 but is rank 3", op, "[1,2,3];?;?");
}
377 
// LRNGrad: the output is the merge of all three inputs, which must be rank 4.
TEST(NNOpsTest, LRNGrad_ShapeFn) {
  ShapeInferenceTestOp op("LRNGrad");

  // Input 0 must be rank 4, and the other two inputs must agree with it.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?");
  for (const char* mismatched : {"?;[1,2,3];?", "?;?;[1,2,3]"}) {
    INFER_ERROR("Shapes must be equal rank, but are 4 and 3", op, mismatched);
  }

  // Each known dimension is taken from whichever input supplies it.
  INFER_OK(op, "[1,?,?,4];[?,2,?,?];[?,?,3,?]", "[d0_0,d1_1,d2_2,d0_3]");
}
389 
// MaxPoolGrad and MaxPoolGradWithArgmax both require a rank-4 input 0 and
// forward it as the output shape.
TEST(NNOpsTest, MaxPoolGrad_ShapeFn) {
  for (const char* grad_op : {"MaxPoolGrad", "MaxPoolGradWithArgmax"}) {
    ShapeInferenceTestOp op(grad_op);

    // Unknown input 0: unknown rank-4 output.
    INFER_OK(op, "?;?;?", "[?,?,?,?]");
    // Known rank-4 input 0: forwarded unchanged.
    for (const char* ins : {"[?,?,?,?];?;?", "[?,2,?,4];?;?"}) {
      INFER_OK(op, ins, "in0");
    }

    // Wrong rank for input 0.
    INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?");
  }
}
403 
// Dilation2DBackpropInput: the output shape is always input 0's shape.
TEST(NNOpsTest, Dilation2DBackpropInput_ShapeFn) {
  ShapeInferenceTestOp op("Dilation2DBackpropInput");

  // Regardless of input 1, input 0 is forwarded.
  for (const char* ins : {"?;?;?", "?;[?,?,?,?,?];?", "?;[?,2,?,4,?];?"}) {
    INFER_OK(op, ins, "in0");
  }
}
412 
// Dilation2DBackpropFilter: the output shape is always input 1's shape.
TEST(NNOpsTest, Dilation2DBackpropFilter_ShapeFn) {
  ShapeInferenceTestOp op("Dilation2DBackpropFilter");

  // Whether input 1 is unknown or partially known, it is forwarded.
  for (const char* ins : {"?;?;?", "?;[?,?,?,?,?];?", "?;[?,2,?,4,?];?"}) {
    INFER_OK(op, ins, "in1");
  }
}
421 
// For each of these gradient ops, the single output shape is the merge of the
// two input shapes.
TEST(NNOpsTest, MergeBothInputs_ShapeFn) {
  for (const char* grad_op : {"ReluGrad", "Relu6Grad", "EluGrad", "SeluGrad",
                              "SoftplusGrad", "SoftsignGrad"}) {
    ShapeInferenceTestOp op(grad_op);

    // Conflicting known dimensions cannot be merged.
    INFER_ERROR("Dimension 1 in both shapes must be equal, but are 3 and 2", op,
                "[1,3];[?,2]");

    // The known side wins; partially known shapes contribute their known dims.
    INFER_OK(op, "?;?", "in0|in1");
    INFER_OK(op, "[1,?,3];?", "in0");
    INFER_OK(op, "?;[1,?,3]", "in1");
    INFER_OK(op, "[1,?];[?,2]", "[d0_0,d1_1]");
  }
}
435 
// SoftmaxCrossEntropyWithLogits takes two inputs of shape [batch_size, N],
// which may broadcast against each other. It produces outputs of shape
// [batch_size] and [batch_size, N].
TEST(NNOpsTest, SoftmaxCrossEntropyWithLogits_ShapeFn) {
  ShapeInferenceTestOp op("SoftmaxCrossEntropyWithLogits");

  // Both inputs must be (broadcastable to) rank 2.
  INFER_ERROR("Shape must be broadcasted with rank 2", op, "[1,2,3];?");
  INFER_ERROR("Shape must be broadcasted with rank 2", op, "?;[1,2,3]");

  // Merging of known and unknown dimensions.
  INFER_OK(op, "?;?", "[?];[?,?]");
  INFER_OK(op, "[?,?];[?,?]", "[d0_0|d1_0];in0|in1");
  INFER_OK(op, "[1,2];[?,2]", "[d0_0];in0");
  INFER_OK(op, "[1,?];[?,2]", "[d0_0];[d0_0,d0_1|d1_1]");
  INFER_OK(op, "[?,2];[1,2]", "[d1_0];in1");

  // Broadcasting: each pair below is broadcast to [2,4].
  struct BroadcastCase {
    const char* ins;
    const char* outs;
  };
  for (const BroadcastCase& bc : {
           BroadcastCase{"[1,4];[2,4]", "[d1_0];[d1_0,d0_1|d1_1]"},
           BroadcastCase{"[2,4];[2,1]", "[d0_0];[d0_0|d1_0,d0_1]"},
           BroadcastCase{"[1,?];[2,4]", "[d1_0];[d1_0,d0_1|d1_1]"},
           BroadcastCase{"[2,4];[?,1]", "[d0_0];[d0_0|d1_0,d0_1]"},
       }) {
    INFER_OK(op, bc.ins, bc.outs);
  }
}
461 
// SparseSoftmaxCrossEntropyWithLogits takes inputs of shape [batch_size, N]
// and [batch_size]. It produces outputs of shape [batch_size] and
// [batch_size, N].
TEST(NNOpsTest, SparseSoftmaxCrossEntropyWithLogits_ShapeFn) {
  ShapeInferenceTestOp op("SparseSoftmaxCrossEntropyWithLogits");

  // Rank and batch-size mismatches are rejected.
  INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[1,2,3];?");
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]");
  INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,?];[2]");

  // batch_size is the merge of both inputs' first dimension.
  INFER_OK(op, "?;?", "[?];[?,?]");
  INFER_OK(op, "[?,?];[?]", "[d0_0|d1_0];[d0_0|d1_0,d0_1]");
  INFER_OK(op, "[1,2];[1]", "[d0_0|d1_0];[d0_0|d1_0,d0_1]");
  INFER_OK(op, "[?,2];[1]", "[d1_0];[d1_0,d0_1]");
}
476 
// InTopK takes inputs of shape [batch_size, N] and [batch_size]; the output
// has shape [batch_size].
TEST(NNOpsTest, InTopK_ShapeFn) {
  ShapeInferenceTestOp op("InTopK");

  // Rank and batch-size mismatches are rejected.
  INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[1,2,3];?");
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]");
  INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,?];[2]");

  // batch_size is the merge of both inputs' first dimension.
  INFER_OK(op, "?;?", "[?]");
  INFER_OK(op, "[?,?];[?]", "[d0_0|d1_0]");
  INFER_OK(op, "[1,2];[1]", "[d0_0|d1_0]");
  INFER_OK(op, "[?,2];[1]", "[d1_0]");
}
490 
// Dilation2D output spatial sizes depend on strides, rates and padding.
TEST(NNOpsTest, Dilation2DShapeTest) {
  ShapeInferenceTestOp op("Dilation2D");

  // Rebuilds op.node_def with the given strides, rates and padding attrs.
  auto set_op = [&op](const std::vector<int32>& stride_vals,
                      const std::vector<int32>& rate_vals,
                      const string& pad) {
    TF_ASSERT_OK(NodeDefBuilder("test", "Dilation2D")
                     .Input("input", 0, DT_FLOAT)
                     .Input("filter", 0, DT_FLOAT)
                     .Attr("strides", stride_vals)
                     .Attr("rates", rate_vals)
                     .Attr("padding", pad)
                     .Finalize(&op.node_def));
  };
  const std::vector<int32> ones = {1, 1, 1, 1};

  // Unit rates leave the 1x1 filter as is. With unit strides and VALID
  // padding, the input's 2x2 spatial extent is preserved.
  set_op(ones, ones, "VALID");
  INFER_OK(op, "[1,2,2,2];[1,1,2]", "[d0_0,2,2,d1_2]");

  // A rate of 2 dilates the 2x2 filter to an effective 2 + (2 - 1) = 3x3.
  // A 7x7 input with unit strides and VALID padding therefore yields a 5x5
  // output.
  set_op(ones, {1, 2, 2, 1}, "VALID");
  INFER_OK(op, "[1,7,7,2];[2,2,2]", "[d0_0,5,5,d1_2]");
}
515 
// FractionalAvgPool and FractionalMaxPool share a shape function. Each known
// dim of the rank-4 input is divided by its pooling ratio. Outputs 1 and 2 are
// vectors sized like dims 1 and 2 of output 0.
TEST(NNOpsTest, FractionalPool_ShapeFn) {
  for (const char* op_name : {"FractionalAvgPool", "FractionalMaxPool"}) {
    ShapeInferenceTestOp op(op_name);

    // Rebuilds op.node_def for `op_name` with the given pooling_ratio attr.
    auto set_pooling_ratio = [&op, op_name](const std::vector<float>& ratios) {
      TF_ASSERT_OK(NodeDefBuilder("test", op_name)
                       .Input("input", 0, DT_FLOAT)
                       .Attr("pooling_ratio", ratios)
                       .Finalize(&op.node_def));
    };

    set_pooling_ratio(std::vector<float>{2.0f, 1, 1 / 1.5f, 1 / 2.0f});

    // The input must be rank 4.
    INFER_ERROR("must be rank 4", op, "[?,?,?]");

    // Unknown dims stay unknown; known dims are scaled by the ratios.
    struct ShapeCase {
      const char* ins;
      const char* outs;
    };
    for (const ShapeCase& sc : {
             ShapeCase{"?", "[?,?,?,?];[?];[?]"},
             ShapeCase{"[?,?,?,?]", "[?,?,?,?];[?];[?]"},
             ShapeCase{"[10,20,30,40]", "[5,20,45,80];[20];[45]"},
             ShapeCase{"[?,20,30,40]", "[?,20,45,80];[20];[45]"},
             ShapeCase{"[10,?,30,40]", "[5,?,45,80];[?];[45]"},
             ShapeCase{"[10,20,?,40]", "[5,20,?,80];[20];[?]"},
             ShapeCase{"[10,20,30,?]", "[5,20,45,?];[20];[45]"},
         }) {
      INFER_OK(op, sc.ins, sc.outs);
    }

    // pooling_ratio with the wrong number of entries is rejected.
    set_pooling_ratio(std::vector<float>{.5, 1.0, 1.5});
    INFER_ERROR("pooling_ratio field", op, "?");
    set_pooling_ratio(std::vector<float>{1, 2, 3, 4, 5});
    INFER_ERROR("pooling_ratio field", op, "?");

    // A negative ratio leads to a negative output dim, which is rejected.
    set_pooling_ratio(std::vector<float>{-1, 2, 3, 4});
    INFER_ERROR("is negative", op, "[1,2,3,4]");
  }
}
552 
// FractionalMaxPoolGrad: the shape function only looks at input 0. Input 0
// must be rank 4 and becomes the output shape.
TEST(NNOpsTest, FractionalMaxPoolGrad) {
  ShapeInferenceTestOp op("FractionalMaxPoolGrad");

  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?]");
  INFER_OK(op, "[?,?,3,4];?;?;?;?", "in0");
  INFER_ERROR("must be rank 4", op, "[?,?,?];?;?;?;?");
}
561 
// FractionalAvgPoolGrad: the output shape is read from the value of input 0
// when it is available.
TEST(NNOpsTest, FractionalAvgPoolGrad) {
  ShapeInferenceTestOp op("FractionalAvgPoolGrad");
  // Make room for a constant value for input 0 (initially absent).
  op.input_tensors.resize(1);

  // Without a value for input 0, the output is an unknown rank-4 shape.
  INFER_OK(op, "?;?;?;?", "[?,?,?,?]");

  // Once input 0's value is known, it is used verbatim as the output shape.
  const std::vector<int32> dims = {1, 2, 3, 4};
  Tensor dims_t = test::AsTensor<int32>(dims);
  op.input_tensors[0] = &dims_t;
  INFER_OK(op, "[5];?;?;?", "[1,2,3,4]");
}
575 
576 }  // end namespace tensorflow
577