1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/fake_input.h"
17 #include "tensorflow/core/framework/node_def_builder.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/shape_inference_testutil.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/lib/core/status_test_util.h"
22 #include "tensorflow/core/platform/test.h"
23
24 namespace tensorflow {
25
TEST(NNOpsTest, TopK_ShapeFn) {
  ShapeInferenceTestOp op("TopK");
  // Rebuilds op.node_def with the given `k` attribute. The NodeDef names the
  // same op ("TopK") whose shape function is under test; it previously named
  // "Pack" by copy-paste, which only worked because the shape function reads
  // nothing but the `k` attr and input 0.
  auto set_k = [&op](int k) {
    TF_ASSERT_OK(NodeDefBuilder("test", "TopK")
                     .Input("a", 0, DT_FLOAT)
                     .Attr("k", k)
                     .Finalize(&op.node_def));
  };

  set_k(20);
  // With unknown input, each output is an unknown shape.
  INFER_OK(op, "?", "?;?");
  // With vector input, each output is [k].
  INFER_OK(op, "[20]", "[20];[20]");
  INFER_OK(op, "[21]", "[20];[20]");

  // With input rank 3, each output is the first 2 dims of input, plus k.
  INFER_OK(op, "[1,?,21]", "[d0_0,d0_1,20];[d0_0,d0_1,20]");
  // With input rank 4, each output is the first 3 dims of input, plus k.
  INFER_OK(op, "[1,?,21,?]", "[d0_0,d0_1,d0_2,20];[d0_0,d0_1,d0_2,20]");

  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]");
  INFER_ERROR("input must have last dimension >= k = 20 but is 1", op, "[1]");
  INFER_ERROR("input must have last dimension >= k = 20 but is 4", op,
              "[1,2,3,4]");
  // A negative k attribute is rejected by the shape function.
  set_k(-1);
  INFER_ERROR("Need k >= 0, got -1", op, "[1,2,3,4]");
}
54
TEST(NNOpsTest, TopKV2_ShapeFn) {
  ShapeInferenceTestOp op("TopKV2");
  op.input_tensors.resize(2);

  // `k` is input 1. Exposing it as a constant tensor lets the shape function
  // read its value.
  Tensor k_t;
  op.input_tensors[1] = &k_t;

  k_t = test::AsScalar<int32>(20);
  // With unknown input, each output is an unknown shape.
  INFER_OK(op, "?;[]", "?;?");
  // With vector input, each output is [k].
  INFER_OK(op, "[20];[]", "[20];[20]");

  // With input rank 3, each output is the first 2 dims of input, plus k.
  INFER_OK(op, "[1,?,21];[]", "[d0_0,d0_1,20];[d0_0,d0_1,20]");
  // With input rank 4, each output is the first 3 dims of input, plus k.
  INFER_OK(op, "[1,?,21,?];[]", "[d0_0,d0_1,d0_2,20];[d0_0,d0_1,d0_2,20]");

  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[]");
  INFER_ERROR("input must have last dimension >= k = 20 but is 1", op,
              "[1];[]");
  INFER_ERROR("input must have last dimension >= k = 20 but is 4", op,
              "[1,2,3,4];[]");
  // A negative k value is rejected.
  k_t = test::AsScalar<int32>(-1);
  INFER_ERROR(
      "Dimension size, given by scalar input 1, must be non-negative but is -1",
      op, "[1,2,3,4];[]");
}
83
TEST(NNOpsTest, NthElement_ShapeFn) {
  // NthElement drops the last dimension of input 0. Input 1 (`n`) is supplied
  // as a constant scalar tensor so the shape function can compare it against
  // that last dimension.
  ShapeInferenceTestOp op("NthElement");
  Tensor n_tensor = test::AsScalar<int32>(20);
  op.input_tensors.resize(2);
  op.input_tensors[1] = &n_tensor;

  // Failures with n = 20: scalar input, or a last dim that is not > n.
  INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[];[]");
  INFER_ERROR("Input must have last dimension > n = 20 but is 1", op, "[1];[]");
  INFER_ERROR("Input must have last dimension > n = 20 but is 20", op,
              "[1,2,3,20];[]");

  // Successes with n = 20: the output keeps every dim but the last.
  INFER_OK(op, "?;[]", "?");
  INFER_OK(op, "[21];[]", "[]");
  INFER_OK(op, "[2,?,?];[]", "[d0_0,d0_1]");
  INFER_OK(op, "[?,3,?,21];[]", "[d0_0,d0_1,d0_2]");

  // A negative n is rejected.
  n_tensor = test::AsScalar<int32>(-1);
  INFER_ERROR(
      "Dimension size, given by scalar input 1, must be non-negative but is -1",
      op, "[1,2,3,4];[]");
}
106
TEST(NNOpsTest, BatchNormWithGlobalNormalization_ShapeFn) {
  ShapeInferenceTestOp op("BatchNormWithGlobalNormalization");

  // Builds a 5-input spec that is "?" everywhere except position `pos`,
  // which is set to `shape`.
  auto spec_with = [](int pos, const std::string& shape) {
    std::string spec;
    for (int idx = 0; idx < 5; ++idx) {
      if (idx > 0) spec += ';';
      spec += (idx == pos) ? shape : std::string("?");
    }
    return spec;
  };

  // Input 0 must be 4-D; inputs 1 through 4 must each be 1-D.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op,
              spec_with(0, "[1,2,3]"));
  for (int pos = 1; pos <= 4; ++pos) {
    INFER_ERROR("Shape must be rank 1 but is rank 3", op,
                spec_with(pos, "[1,2,3]"));
  }

  // The output is 4-D. Its last dim merges input 0's last dim with the
  // single dim of each of inputs 1 through 4.
  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?]");
  for (int pos = 1; pos <= 4; ++pos) {
    INFER_OK(op, spec_with(pos, "[1]"),
             "[?,?,?,d" + std::to_string(pos) + "_0]");
  }
  INFER_OK(op, "[1,2,3,4];[4];[4];[4];[4]",
           "[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0|d3_0|d4_0]");
}
126
TEST(NNOpsTest, QuantizedBatchNormWithGlobalNormalization_ShapeFn) {
  // Same expectations as BatchNormWithGlobalNormalization_ShapeFn. Here each
  // of the five tensor inputs is followed by two extra inputs (its min and
  // max), and two scalar outputs (min and max) follow the main output.
  ShapeInferenceTestOp op("QuantizedBatchNormWithGlobalNormalization");

  // Expands five tensor shapes into the 15-input spec. Every tensor is
  // followed by `min_max` twice.
  auto make_spec = [](const std::vector<std::string>& tensors,
                      const std::string& min_max) {
    std::string spec;
    for (const std::string& tensor : tensors) {
      if (!spec.empty()) spec += ";";
      spec += tensor + ";" + min_max + ";" + min_max;
    }
    return spec;
  };

  // Rank checks: tensor 0 must be 4-D; tensors 1 through 4 must be 1-D.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op,
              make_spec({"[1,2,3]", "?", "?", "?", "?"}, "?"));
  for (int pos = 1; pos < 5; ++pos) {
    std::vector<std::string> tensors(5, "?");
    tensors[pos] = "[1,2,3]";
    INFER_ERROR("Shape must be rank 1 but is rank 3", op,
                make_spec(tensors, "?"));
  }

  // The last dim of tensor 0 (input 0) is merged with the single dim of each
  // other tensor (inputs 3, 6, 9 and 12). Both trailing outputs are scalars.
  INFER_OK(op, make_spec({"?", "?", "?", "?", "?"}, "[]"), "[?,?,?,?];[];[]");
  for (int pos = 1; pos < 5; ++pos) {
    std::vector<std::string> tensors(5, "?");
    tensors[pos] = "[1]";
    INFER_OK(op, make_spec(tensors, "[]"),
             "[?,?,?,d" + std::to_string(3 * pos) + "_0];[];[]");
  }
  INFER_OK(op, make_spec({"[1,2,3,4]", "[4]", "[4]", "[4]", "[4]"}, "[]"),
           "[d0_0,d0_1,d0_2,d0_3|d3_0|d6_0|d9_0|d12_0];[];[]");
}
158
TEST(NNOpsTest, BatchNormWithGlobalNormalizationGrad_ShapeFn) {
  ShapeInferenceTestOp op("BatchNormWithGlobalNormalizationGrad");

  // Builds a 5-input spec that is "?" everywhere except position `pos`,
  // which is set to `shape`.
  auto spec_with = [](int pos, const std::string& shape) {
    std::string spec;
    for (int idx = 0; idx < 5; ++idx) {
      if (idx > 0) spec += ';';
      spec += (idx == pos) ? shape : std::string("?");
    }
    return spec;
  };

  // Input 0 must be 4-D, inputs 1 through 3 must be 1-D, and input 4 must
  // have the same rank as input 0.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op,
              spec_with(0, "[1,2,3]"));
  for (int pos = 1; pos <= 3; ++pos) {
    INFER_ERROR("Shape must be rank 1 but is rank 3", op,
                spec_with(pos, "[1,2,3]"));
  }
  INFER_ERROR("Shapes must be equal rank, but are 4 and 3", op,
              spec_with(4, "[1,2,3]"));

  // Output 0 is the merge of inputs 0 and 4. Its last dim is further merged
  // with the single dim of inputs 1 through 3, and that merged dim is the
  // shape of each of outputs 1 through 4.
  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
  for (int pos = 1; pos <= 3; ++pos) {
    const std::string dim = "d" + std::to_string(pos) + "_0";
    INFER_OK(op, spec_with(pos, "[1]"),
             "[?,?,?," + dim + "];[" + dim + "];[" + dim + "];[" + dim +
                 "];[" + dim + "]");
  }
  INFER_OK(op, "[1,?,3,?];[?];[?];[?];[?,2,?,4]",
           "[d0_0,d4_1,d0_2,d4_3];[d4_3];[d4_3];[d4_3];[d4_3]");
}
180
TEST(NNOpsTest, FusedBatchNorm_ShapeFn) {
  ShapeInferenceTestOp op("FusedBatchNorm");

  // Rebuilds op.node_def with five float inputs and the given attributes.
  auto set_op = [&op](bool is_training, float exponential_avg_factor,
                      string data_format) {
    TF_ASSERT_OK(NodeDefBuilder("test", "FusedBatchNorm")
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_FLOAT))
                     .Attr("data_format", data_format)
                     .Attr("is_training", is_training)
                     .Attr("exponential_avg_factor", exponential_avg_factor)
                     .Finalize(&op.node_def));
  };

  // Training mode, NHWC. The expectations are identical for every
  // exponential_avg_factor value listed, so they are checked for each one.
  for (const float exponential_avg_factor : {1.0f, 0.5f}) {
    set_op(true, exponential_avg_factor, "NHWC");
    // Rank errors: input 0 must be 4-D; inputs 1 and 2 must be 1-D.
    INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
    INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
    INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
    // The channel dim of input 0 (its last dim for NHWC) is merged with the
    // single dim of inputs 1 and 2. Every 1-D output has that merged size.
    // Inputs 3 and 4 are left unknown in these checks.
    INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
    INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0];[d1_0];[d1_0];[d1_0];[d1_0]");
    INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
    INFER_OK(op, "[1,2,3,4];[4];[4];?;?",
             "[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0];"
             "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0];"
             "[d0_3|d1_0|d2_0];[d0_3|d1_0|d2_0]");
  }

  // Training mode, NCHW.
  set_op(true, 1.0, "NCHW");
  // Rank errors: input 0 must be 4-D; inputs 1 and 2 must be 1-D.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
  // The channel dim of input 0 (dim 1 for NCHW) is merged with the single dim
  // of inputs 1 and 2. Inputs 3 and 4 are left unknown in these checks.
  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
  INFER_OK(op, "?;[1];?;?;?", "[?,d1_0,?,?];[d1_0];[d1_0];[d1_0];[d1_0]");
  INFER_OK(op, "?;?;[1];?;?", "[?,d2_0,?,?];[d2_0];[d2_0];[d2_0];[d2_0]");
  INFER_OK(op, "[1,4,2,3];[4];[4];?;?",
           "[d0_0,d0_1|d1_0|d2_0,d0_2,d0_3];"
           "[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0];"
           "[d0_1|d1_0|d2_0];[d0_1|d1_0|d2_0]");

  // Inference mode, NHWC.
  set_op(false, 1.0, "NHWC");
  // Rank errors: input 0 must be 4-D; inputs 1 through 4 must be 1-D.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
  // The channel (last) dim of input 0 is merged with the single dim of all
  // four other inputs.
  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
  INFER_OK(op, "?;[1];?;?;?", "[?,?,?,d1_0];[d1_0];[d1_0];[d1_0];[d1_0]");
  INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[d2_0];[d2_0]");
  INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[d3_0];[d3_0]");
  INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[d4_0];[d4_0]");
  INFER_OK(op, "[1,2,3,4];[4];[4];[4];[4]",
           "[d0_0,d0_1,d0_2,d0_3|d1_0|d2_0|d3_0|d4_0];"
           "[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0];"
           "[d0_3|d1_0|d2_0|d3_0|d4_0];[d0_3|d1_0|d2_0|d3_0|d4_0]");

  // Inference mode, NCHW.
  set_op(false, 1.0, "NCHW");
  // Rank errors: input 0 must be 4-D; inputs 1 through 4 must be 1-D.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;[1,2,3];?;?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
  // The channel dim (dim 1) of input 0 is merged with the single dim of all
  // four other inputs.
  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[?];[?]");
  INFER_OK(op, "?;[1];?;?;?", "[?,d1_0,?,?];[d1_0];[d1_0];[d1_0];[d1_0]");
  INFER_OK(op, "?;?;[1];?;?", "[?,d2_0,?,?];[d2_0];[d2_0];[d2_0];[d2_0]");
  INFER_OK(op, "?;?;?;[1];?", "[?,d3_0,?,?];[d3_0];[d3_0];[d3_0];[d3_0]");
  INFER_OK(op, "?;?;?;?;[1]", "[?,d4_0,?,?];[d4_0];[d4_0];[d4_0];[d4_0]");
  INFER_OK(op, "[1,4,2,3];[4];[4];[4];[4]",
           "[d0_0,d0_1|d1_0|d2_0|d3_0|d4_0,d0_2,d0_3];"
           "[d0_1|d1_0|d2_0|d3_0|d4_0];[d0_1|d1_0|d2_0|d3_0|d4_0];"
           "[d0_1|d1_0|d2_0|d3_0|d4_0];[d0_1|d1_0|d2_0|d3_0|d4_0]");
}
276
TEST(NNOpsTest, FusedBatchNormGrad_ShapeFn) {
  ShapeInferenceTestOp op("FusedBatchNormGrad");
  // Rebuilds op.node_def with five float inputs and the given data format.
  auto set_op = [&op](string data_format) {
    TF_ASSERT_OK(NodeDefBuilder("test", "FusedBatchNormGrad")
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_FLOAT))
                     .Attr("data_format", data_format)
                     .Finalize(&op.node_def));
  };

  set_op("NCHW");
  // Rank errors: inputs 0 and 1 must be 4-D; inputs 2, 3 and 4 must be 1-D.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "?;[1,2,3];?;?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
  // The channel dim of input 0 (dim 1 for NCHW) is merged with the single dim
  // of the 1-D inputs 2, 3 and 4. That merged dim is the shape of outputs 1
  // and 2. Outputs 3 and 4 are always [0].
  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
  INFER_OK(op, "?;?;[1];?;?", "[?,d2_0,?,?];[d2_0];[d2_0];[0];[0]");
  INFER_OK(op, "?;?;?;[1];?", "[?,d3_0,?,?];[d3_0];[d3_0];[0];[0]");
  INFER_OK(op, "?;?;?;?;[1]", "[?,d4_0,?,?];[d4_0];[d4_0];[0];[0]");
  INFER_OK(op, "[1,4,2,3];[1,4,2,3];[4];[4];[4]",
           "[d0_0,d0_1|d2_0|d3_0|d4_0,d0_2,d0_3];"
           "[d0_1|d2_0|d3_0|d4_0];[d0_1|d2_0|d3_0|d4_0];[0];[0]");

  set_op("NHWC");
  // Rank errors: inputs 0 and 1 must be 4-D; inputs 2, 3 and 4 must be 1-D.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?;?;?");
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "?;[1,2,3];?;?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;[1,2,3];?;?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;[1,2,3];?");
  INFER_ERROR("Shape must be rank 1 but is rank 3", op, "?;?;?;?;[1,2,3]");
  // Same as NCHW, except that the channel dim of input 0 is its last dim.
  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?];[?];[?];[0];[0]");
  INFER_OK(op, "?;?;[1];?;?", "[?,?,?,d2_0];[d2_0];[d2_0];[0];[0]");
  INFER_OK(op, "?;?;?;[1];?", "[?,?,?,d3_0];[d3_0];[d3_0];[0];[0]");
  INFER_OK(op, "?;?;?;?;[1]", "[?,?,?,d4_0];[d4_0];[d4_0];[0];[0]");
  INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4];[4]",
           "[d0_0,d0_1,d0_2,d0_3|d2_0|d3_0|d4_0];"
           "[d0_3|d2_0|d3_0|d4_0];[d0_3|d2_0|d3_0|d4_0];[0];[0]");
}
322
TEST(NNOpsTest, Conv2DBackpropInput_ShapeFn) {
  ShapeInferenceTestOp op("Conv2DBackpropInput");

  // Input 0 (input_sizes) must hold 4 or 2 values, and input 2 must be 4-D.
  INFER_ERROR("input_sizes to contain 4 values or 2 values", op,
              "[3];[?,?,?,?];[?,?,?,?]");
  INFER_ERROR("Shape must be rank 4 but is rank 3", op,
              "[4];[?,?,?,?];[?,?,?]");

  // 2-value input_sizes: the channel dim of the input grad always matches
  // the filter shape (dim 2 of input 1).
  INFER_OK(op, "[2];[?,?,2,?];[1,?,?,?]", "[d2_0,?,?,d1_2]");
  // 4-value input_sizes: for a grouped convolution the input grad's channel
  // size need not equal the filter's input channel size. It is therefore
  // taken from the contents of input_sizes and is unknown here.
  INFER_OK(op, "[4];[?,?,2,?];[1,?,?,?]", "[d2_0,?,?,?]");
}
341
TEST(NNOpsTest, Conv3DBackpropInput_ShapeFn) {
  ShapeInferenceTestOp op("Conv3DBackpropInput");

  // Test rank error: input 0 must be 5-D.
  INFER_ERROR("Shape must be rank 5 but is rank 3", op, "[1,2,3];?;?");

  // input[0] (not input[1]) is transferred to output after asserting its rank.
  INFER_OK(op, "?;?;?", "[?,?,?,?,?]");
  INFER_OK(op, "[?,?,?,?,?];?;?", "in0");
  INFER_OK(op, "[?,2,?,4,?];?;?", "in0");
}
353
TEST(NNOpsTest, Conv3DBackpropFilter_ShapeFn) {
  ShapeInferenceTestOp op("Conv3DBackpropFilter");

  // Input 1 must be 5-D and is then passed through unchanged as the output.
  INFER_ERROR("Shape must be rank 5 but is rank 3", op, "?;[1,2,3];?");
  INFER_OK(op, "?;?;?", "[?,?,?,?,?]");
  for (const char* spec : {"?;[?,?,?,?,?];?", "?;[?,2,?,4,?];?"}) {
    INFER_OK(op, spec, "in1");
  }
}
365
TEST(NNOpsTest, MaxPool3DGrad_ShapeFn) {
  ShapeInferenceTestOp op("MaxPool3DGrad");

  // Input 0 must be 5-D and is then passed through unchanged as the output.
  INFER_ERROR("Shape must be rank 5 but is rank 3", op, "[1,2,3];?;?");
  INFER_OK(op, "?;?;?", "[?,?,?,?,?]");
  for (const char* spec : {"[?,?,?,?,?];?;?", "[?,2,?,4,?];?;?"}) {
    INFER_OK(op, spec, "in0");
  }
}
377
TEST(NNOpsTest, LRNGrad_ShapeFn) {
  ShapeInferenceTestOp op("LRNGrad");

  // Input 0 must be 4-D, and the other two inputs must match its rank.
  INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?");
  for (const char* spec : {"?;[1,2,3];?", "?;?;[1,2,3]"}) {
    INFER_ERROR("Shapes must be equal rank, but are 4 and 3", op, spec);
  }

  // The output is the dimension-wise merge of all three 4-D inputs.
  INFER_OK(op, "[1,?,?,4];[?,2,?,?];[?,?,3,?]", "[d0_0,d1_1,d2_2,d0_3]");
}
389
TEST(NNOpsTest, MaxPoolGrad_ShapeFn) {
  // Both ops share the same contract: input 0 must be 4-D and is passed
  // through unchanged as the output.
  const char* const grad_ops[] = {"MaxPoolGrad", "MaxPoolGradWithArgmax"};
  for (const char* name : grad_ops) {
    ShapeInferenceTestOp op(name);

    INFER_ERROR("Shape must be rank 4 but is rank 3", op, "[1,2,3];?;?");
    INFER_OK(op, "?;?;?", "[?,?,?,?]");
    for (const char* spec : {"[?,?,?,?];?;?", "[?,2,?,4];?;?"}) {
      INFER_OK(op, spec, "in0");
    }
  }
}
403
TEST(NNOpsTest, Dilation2DBackpropInput_ShapeFn) {
  ShapeInferenceTestOp op("Dilation2DBackpropInput");

  // In every case below the output shape is input 0, unchanged.
  for (const char* spec : {"?;?;?", "?;[?,?,?,?,?];?", "?;[?,2,?,4,?];?"}) {
    INFER_OK(op, spec, "in0");
  }
}
412
TEST(NNOpsTest, Dilation2DBackpropFilter_ShapeFn) {
  ShapeInferenceTestOp op("Dilation2DBackpropFilter");

  // In every case below the output shape is input 1, unchanged.
  for (const char* spec : {"?;?;?", "?;[?,?,?,?,?];?", "?;[?,2,?,4,?];?"}) {
    INFER_OK(op, spec, "in1");
  }
}
421
TEST(NNOpsTest, MergeBothInputs_ShapeFn) {
  // For each of these ops the output shape is the merge of its two inputs.
  const char* const merging_ops[] = {"ReluGrad",     "Relu6Grad",
                                     "EluGrad",      "SeluGrad",
                                     "SoftplusGrad", "SoftsignGrad"};
  for (const char* op_name : merging_ops) {
    ShapeInferenceTestOp op(op_name);

    // Conflicting known dims cannot be merged.
    INFER_ERROR("Dimension 1 in both shapes must be equal, but are 3 and 2", op,
                "[1,3];[?,2]");

    // An unknown side defers to the other. Partially known shapes combine
    // dimension by dimension.
    INFER_OK(op, "?;?", "in0|in1");
    INFER_OK(op, "?;[1,?,3]", "in1");
    INFER_OK(op, "[1,?,3];?", "in0");
    INFER_OK(op, "[1,?];[?,2]", "[d0_0,d1_1]");
  }
}
435
TEST(NNOpsTest, SoftmaxCrossEntropyWithLogits_ShapeFn) {
  ShapeInferenceTestOp op("SoftmaxCrossEntropyWithLogits");

  // Both inputs are [batch_size, N]. The outputs are [batch_size] and
  // [batch_size, N].
  INFER_OK(op, "?;?", "[?];[?,?]");
  INFER_OK(op, "[?,?];[?,?]", "[d0_0|d1_0];in0|in1");
  INFER_OK(op, "[1,2];[?,2]", "[d0_0];in0");
  INFER_OK(op, "[1,?];[?,2]", "[d0_0];[d0_0,d0_1|d1_1]");
  INFER_OK(op, "[?,2];[1,2]", "[d1_0];in1");

  // A rank-3 input on either side is rejected.
  for (const char* spec : {"[1,2,3];?", "?;[1,2,3]"}) {
    INFER_ERROR("Shape must be broadcasted with rank 2", op, spec);
  }

  // Broadcast cases. Each pair of inputs broadcasts to [2,4].
  struct BroadcastCase {
    const char* inputs;
    const char* outputs;
  };
  const BroadcastCase broadcast_cases[] = {
      {"[1,4];[2,4]", "[d1_0];[d1_0,d0_1|d1_1]"},
      {"[2,4];[2,1]", "[d0_0];[d0_0|d1_0,d0_1]"},
      {"[1,?];[2,4]", "[d1_0];[d1_0,d0_1|d1_1]"},
      {"[2,4];[?,1]", "[d0_0];[d0_0|d1_0,d0_1]"},
  };
  for (const BroadcastCase& bc : broadcast_cases) {
    INFER_OK(op, bc.inputs, bc.outputs);
  }
}
461
TEST(NNOpsTest, SparseSoftmaxCrossEntropyWithLogits_ShapeFn) {
  ShapeInferenceTestOp op("SparseSoftmaxCrossEntropyWithLogits");

  // Input 0 is [batch_size, N] and input 1 is [batch_size]. Wrong ranks or
  // disagreeing batch dims are errors.
  INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[1,2,3];?");
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]");
  INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,?];[2]");

  // The outputs are [batch_size] and [batch_size, N], with batch_size merged
  // from both inputs.
  INFER_OK(op, "?;?", "[?];[?,?]");
  INFER_OK(op, "[?,?];[?]", "[d0_0|d1_0];[d0_0|d1_0,d0_1]");
  INFER_OK(op, "[1,2];[1]", "[d0_0|d1_0];[d0_0|d1_0,d0_1]");
  INFER_OK(op, "[?,2];[1]", "[d1_0];[d1_0,d0_1]");
}
476
TEST(NNOpsTest, InTopK_ShapeFn) {
  ShapeInferenceTestOp op("InTopK");

  // Input 0 must be 2-D ([batch_size, N]) and input 1 must be 1-D
  // ([batch_size]). Their batch dims must agree.
  INFER_ERROR("Shape must be rank 2 but is rank 3", op, "[1,2,3];?");
  INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,2]");
  INFER_ERROR("Dimensions must be equal, but are 1 and 2", op, "[1,?];[2]");

  // The single output is [batch_size], merged from both inputs.
  INFER_OK(op, "?;?", "[?]");
  INFER_OK(op, "[?,?];[?]", "[d0_0|d1_0]");
  INFER_OK(op, "[1,2];[1]", "[d0_0|d1_0]");
  INFER_OK(op, "[?,2];[1]", "[d1_0]");
}
490
TEST(NNOpsTest, Dilation2DShapeTest) {
  ShapeInferenceTestOp op("Dilation2D");
  // Rebuilds op.node_def with float "input" and "filter" inputs and the given
  // strides, rates and padding attributes.
  auto configure = [&op](const std::vector<int32>& strides,
                         const std::vector<int32>& rates,
                         const string& padding) {
    TF_ASSERT_OK(NodeDefBuilder("test", "Dilation2D")
                     .Input("input", 0, DT_FLOAT)
                     .Input("filter", 0, DT_FLOAT)
                     .Attr("strides", strides)
                     .Attr("rates", rates)
                     .Attr("padding", padding)
                     .Finalize(&op.node_def));
  };

  // Unit strides and rates leave the 1x1 filter unchanged, so with VALID
  // padding the 2x2 spatial size is preserved.
  configure({1, 1, 1, 1}, {1, 1, 1, 1}, "VALID");
  INFER_OK(op, "[1,2,2,2];[1,1,2]", "[d0_0,2,2,d1_2]");

  // A rate of 2 dilates the 2x2 filter to 2 + (2 - 1) = 3 per side. With
  // unit stride and VALID padding, a 7x7 input then yields a 5x5 output.
  configure({1, 1, 1, 1}, {1, 2, 2, 1}, "VALID");
  INFER_OK(op, "[1,7,7,2];[2,2,2]", "[d0_0,5,5,d1_2]");
}
515
TEST(NNOpsTest, FractionalPool_ShapeFn) {
  const char* const pool_ops[] = {"FractionalAvgPool", "FractionalMaxPool"};
  for (const char* op_name : pool_ops) {
    ShapeInferenceTestOp op(op_name);
    // Rebuilds op.node_def with one float input and the given pooling_ratio.
    auto set_op = [&op, op_name](const std::vector<float>& pooling_ratio) {
      TF_ASSERT_OK(NodeDefBuilder("test", op_name)
                       .Input("input", 0, DT_FLOAT)
                       .Attr("pooling_ratio", pooling_ratio)
                       .Finalize(&op.node_def));
    };

    // Each output dim is the input dim divided by its ratio (2, 1, 1/1.5,
    // 1/2). The two extra 1-D outputs are sized by output dims 1 and 2.
    set_op(std::vector<float>{2.0f, 1, 1 / 1.5f, 1 / 2.0f});

    INFER_ERROR("must be rank 4", op, "[?,?,?]");

    INFER_OK(op, "?", "[?,?,?,?];[?];[?]");
    INFER_OK(op, "[?,?,?,?]", "[?,?,?,?];[?];[?]");

    INFER_OK(op, "[10,20,30,40]", "[5,20,45,80];[20];[45]");
    INFER_OK(op, "[?,20,30,40]", "[?,20,45,80];[20];[45]");
    INFER_OK(op, "[10,?,30,40]", "[5,?,45,80];[?];[45]");
    INFER_OK(op, "[10,20,?,40]", "[5,20,?,80];[20];[?]");
    INFER_OK(op, "[10,20,30,?]", "[5,20,45,?];[20];[45]");

    // A pooling_ratio with 3 or 5 entries is rejected.
    for (const std::vector<float>& bad_ratio :
         {std::vector<float>{.5, 1.0, 1.5}, std::vector<float>{1, 2, 3, 4, 5}}) {
      set_op(bad_ratio);
      INFER_ERROR("pooling_ratio field", op, "?");
    }

    // Output dim sizes must be non-negative, so a negative ratio is rejected.
    set_op(std::vector<float>{-1, 2, 3, 4});
    INFER_ERROR("is negative", op, "[1,2,3,4]");
  }
}
552
TEST(NNOpsTest, FractionalMaxPoolGrad) {
  ShapeInferenceTestOp op("FractionalMaxPoolGrad");

  // Only input 0 matters. It must be 4-D and is returned as the output.
  INFER_OK(op, "[?,?,3,4];?;?;?;?", "in0");
  INFER_OK(op, "?;?;?;?;?", "[?,?,?,?]");
  INFER_ERROR("must be rank 4", op, "[?,?,?];?;?;?;?");
}
561
TEST(NNOpsTest, FractionalAvgPoolGrad) {
  ShapeInferenceTestOp op("FractionalAvgPoolGrad");
  op.input_tensors.resize(1);

  // Without a constant value for input 0, only the output rank (4) is known.
  INFER_OK(op, "?;?;?;?", "[?,?,?,?]");

  // Once input 0 is available as a constant, its values become the output
  // shape.
  const std::vector<int32> dims{1, 2, 3, 4};
  Tensor dims_tensor = test::AsTensor<int32>(dims);
  op.input_tensors[0] = &dims_tensor;
  INFER_OK(op, "[5];?;?;?", "[1,2,3,4]");
}
575
576 } // end namespace tensorflow
577