/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/shape_inference.h"

#include <string>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace {

using ::testing::ContainsRegex;
using ::testing::HasSubstr;
35
36 class ShapeInferenceTest : public ::testing::Test {
37 protected:
38 // Some handy scalar shapes.
39 const Shape s32_ = ShapeUtil::MakeShape(S32, {});
40 const Shape f16_ = ShapeUtil::MakeShape(F16, {});
41 const Shape f32_ = ShapeUtil::MakeShape(F32, {});
42 const Shape f64_ = ShapeUtil::MakeShape(F64, {});
43 const Shape pred_ = ShapeUtil::MakeShape(PRED, {});
44
45 // Some handy vector and matrix shapes of F32 type.
46 // Suffix: vector_length_, matrix_rows_cols_
47 const Shape vector_32_ = ShapeUtil::MakeShape(F32, {32});
48 const Shape vector_64_ = ShapeUtil::MakeShape(F32, {64});
49 const Shape matrix_32_48_ = ShapeUtil::MakeShape(F32, {32, 48});
50 const Shape matrix_32_64_ = ShapeUtil::MakeShape(F32, {32, 64});
51 const Shape matrix_64_48_ = ShapeUtil::MakeShape(F32, {64, 48});
52
53 // Some handy S32 arrays.
54 const Shape s32matrix_64_64_ = ShapeUtil::MakeShape(S32, {64, 64});
55 };
56
57 // Subclass for testing InferReduceShape.
58 class ReduceShapeInferenceTest : public ShapeInferenceTest {
59 protected:
60 // Helper that runs reduce shape inference with the input 'arg' and given
61 // dimensions to reduce, and checks the inferred shape is as expected. The
62 // element type here is hard-coded to F32.
ExpectInferredReduceShape(const Shape & expected_inferred_shape,const Shape & arg,absl::Span<const int64> dimensions_to_reduce)63 void ExpectInferredReduceShape(const Shape& expected_inferred_shape,
64 const Shape& arg,
65 absl::Span<const int64> dimensions_to_reduce) {
66 ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
67 auto inferred_status = ShapeInference::InferReduceShape(
68 {&arg, &f32_}, dimensions_to_reduce, to_apply);
69 EXPECT_IS_OK(inferred_status.status());
70 EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
71 inferred_status.ValueOrDie()));
72 }
73 };
74
75 // Subclass for testing InferSelectAndScatterShape.
76 class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest {
77 protected:
SelectAndScatterShapeInferenceTest()78 SelectAndScatterShapeInferenceTest() {
79 operand_shape_ = ShapeUtil::MakeShape(F32, {8, 16});
80 source_shape_ = ShapeUtil::MakeShape(F32, {4, 8});
81 WindowDimension dim;
82 dim.set_size(2);
83 dim.set_stride(2);
84 dim.set_padding_low(0);
85 dim.set_padding_high(0);
86 dim.set_window_dilation(1);
87 dim.set_base_dilation(1);
88 *window_.add_dimensions() = dim;
89 *window_.add_dimensions() = dim;
90 init_value_shape_ = ShapeUtil::MakeShape(F32, {});
91 select_program_shape_ = ShapeUtil::MakeProgramShape(
92 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
93 scatter_program_shape_ = ShapeUtil::MakeProgramShape(
94 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
95 }
96
97 Shape operand_shape_;
98 Shape source_shape_;
99 Window window_;
100 Shape init_value_shape_;
101 ProgramShape select_program_shape_;
102 ProgramShape scatter_program_shape_;
103 };
104
// Unary negate is shape-preserving: a matrix input infers the same shape.
TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape);
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie()));
}
112
// Select on tuple operands is rejected: operands must be arrays.
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) {
  Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_});
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, tuple, tuple);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Expected array argument for select"));
}
121
// A scalar predicate with non-scalar array operands is rejected: the pred
// must match the operand shapes.
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Operands to select and predicate must be the same shape"));
}
130
// An elementwise select with a matching PRED-array predicate infers the
// operand shape.
TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) {
  auto predarray = ShapeUtil::MakeShape(PRED, {64, 48});
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
138
// Invalid select combinations: mismatched operands, non-PRED predicate,
// predicate/operand shape mismatch, and tuple-shaped predicate.
TEST_F(ShapeInferenceTest, SelectBadShapes) {
  auto inferred_status_error1 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_);
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("Operands to select must be the same shape"));

  auto inferred_status_error2 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("pred operand must have PRED"));

  auto inferred_status_error3 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_,
      matrix_64_48_);
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(
      inferred_status_error3.status().error_message(),
      HasSubstr("Operands to select and predicate must be the same shape"));

  // Tuples have a TUPLE element type and cannot be the pred of a select.
  auto inferred_status_error4 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}),
      ShapeUtil::MakeTupleShape({f32_, f32_}),
      ShapeUtil::MakeTupleShape({f32_, f32_}));
  ASSERT_FALSE(inferred_status_error4.ok());
  ASSERT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("Expected array argument for select pred"));
}
169
// Clamp with three identically-shaped matrices infers that same shape.
TEST_F(ShapeInferenceTest, ClampAllMatrix) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
176
// Clamp with three scalars infers a scalar.
TEST_F(ShapeInferenceTest, ClampAllScalar) {
  auto inferred_status =
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_);
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
}
183
// Scalar min with matrix operand/max is rejected: all clamp shapes must match.
TEST_F(ShapeInferenceTest, ClampMinScalar) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
191
// Scalar max with matrix min/operand is rejected: all clamp shapes must match.
TEST_F(ShapeInferenceTest, ClampMaxScalar) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
199
// Scalar operand with matrix min/max is rejected: all clamp shapes must match.
TEST_F(ShapeInferenceTest, ClampOperandScalar) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
207
// Matrix min with scalar operand/max is rejected: all clamp shapes must match.
TEST_F(ShapeInferenceTest, ClampMinMatrix) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, f32_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
215
// Matrix max with scalar min/operand is rejected: all clamp shapes must match.
TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, f32_, matrix_64_48_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
223
// Matrix operand with scalar min/max is rejected: all clamp shapes must match.
TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, f32_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
231
// Clamp rejects element-type mismatches, dimension mismatches, and
// dimension mismatches where one operand is a scalar.
TEST_F(ShapeInferenceTest, ClampBadShapes) {
  // Type mismatch
  ASSERT_FALSE(
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, s32_, f32_, f32_)
          .ok());
  ASSERT_FALSE(
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, s32_, f32_)
          .ok());
  ASSERT_FALSE(
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, s32_)
          .ok());
  // Dimension mismatch
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
                   HloOpcode::kClamp, vector_64_, vector_32_, vector_32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
                   HloOpcode::kClamp, vector_32_, vector_64_, vector_32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
                   HloOpcode::kClamp, vector_32_, vector_32_, vector_64_)
                   .ok());
  // Dimension mismatch, where one operand is a scalar
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp,
                                                   vector_64_, vector_32_, f32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp,
                                                   vector_64_, f32_, vector_32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_,
                                                   vector_64_, vector_32_)
                   .ok());
}
264
// Complex construction: rejects non-FP inputs, mismatched component types,
// and F16; F32 pairs infer C64, F64 pairs infer C128, and scalar/vector
// broadcasting follows the binary-op broadcast rules.
TEST_F(ShapeInferenceTest, Complex) {
  auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
                           absl::Span<const int64> bcast) {
    return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
                                              bcast);
  };
  // Inputs must be FP.
  ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
  ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
  // Component types must match.
  ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
  // Only F32->C64 and F64->C128 supported.
  ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok());
  // Validate correct uses.
  Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
  TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {})));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));

  Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64});
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(vector_64_, matrix_32_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, vector_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, matrix_32_64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));

  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {})));
}
305
// kTuple of (s32, f32) scalars infers the corresponding tuple shape.
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
  StatusOr<Shape> result =
      ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
                               ShapeUtil::MakeTupleShape({s32_, f32_})));
}
313
// Reduce-window with a 2x2 window, stride 2, no padding on an 8x8 input
// halves each dimension: output is 4x4.
TEST_F(ShapeInferenceTest, ReduceWindowInHalf) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8});
  Window window;
  WindowDimension dim;
  dim.set_size(2);
  dim.set_stride(2);
  dim.set_padding_low(0);
  dim.set_padding_high(0);
  dim.set_window_dilation(1);
  dim.set_base_dilation(1);
  *window.add_dimensions() = dim;
  *window.add_dimensions() = dim;
  Shape window_shape = ShapeUtil::MakeShape(F32, {2, 2});
  Shape init_value_shape = ShapeUtil::MakeShape(F32, {});
  Shape float_scalar = ShapeUtil::MakeShape(F32, {});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      matrix_shape, init_value_shape, window, to_apply);

  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), inferred));
}
338
// With the fixture's consistent shapes, select-and-scatter infers the
// operand shape.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) {
  auto inferred_status_ok = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_IS_OK(inferred_status_ok.status());
  Shape inferred = inferred_status_ok.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, inferred));
}
347
// A source whose shape doesn't match the windowed operand is rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) {
  Shape source_shape_fail = ShapeUtil::MakeShape(F32, {4, 6});
  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, source_shape_fail,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred_status_fail.ok());
  ASSERT_THAT(inferred_status_fail.status().error_message(),
              HasSubstr("Source shape does not match"));
}
357
// A select computation with only one parameter is rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
  ProgramShape select_program_shape_fail =
      ShapeUtil::MakeProgramShape({ShapeUtil::MakeShape(F32, {})}, pred_);
  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_fail, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred_status_fail.ok());
  ASSERT_THAT(inferred_status_fail.status().error_message(),
              HasSubstr("Select function must take 2 parameters"));
}
368
// A select computation returning F32 instead of scalar PRED is rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
  ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_fail, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred_status_fail.ok());
  ASSERT_THAT(inferred_status_fail.status().error_message(),
              HasSubstr("Select function must have rank-0 PRED"));
}
379
// A select computation whose first parameter is S32 (not the operand's F32)
// is rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
  ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_fail, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred_status_fail.ok());
  ASSERT_THAT(inferred_status_fail.status().error_message(),
              HasSubstr("Select function's first parameter"));
}
390
// A select computation whose second parameter is U32 (not the operand's F32)
// is rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
  ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(U32, {})}, pred_);
  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_fail, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred_status_fail.ok());
  ASSERT_THAT(inferred_status_fail.status().error_message(),
              HasSubstr("Select function's second parameter"));
}
401
// All-gather-start infers a (input, gathered-output) tuple; the gather
// dimension is multiplied by the shard count (1 * 8 = 8).
TEST_F(ShapeInferenceTest, AllGatherStart) {
  const Shape operand = ShapeUtil::MakeShape(F32, {1, 8, 4});
  const Shape expected_shape = ShapeUtil::MakeTupleShape(
      {operand, ShapeUtil::MakeShape(F32, {8, 8, 4})});

  auto inferred_ag_shape = ShapeInference::InferAllGatherStartShape(
      {&operand}, /*all_gather_dimension=*/0, /*shard_count=*/8);
  EXPECT_TRUE(inferred_ag_shape.ok());
  EXPECT_TRUE(ShapeUtil::Equal(inferred_ag_shape.ValueOrDie(), expected_shape));
}
412
// All-gather-done unwraps the start op's tuple and yields the gathered
// output shape.
TEST_F(ShapeInferenceTest, AllGatherDone) {
  const Shape input_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1, 8, 4}),
                                 ShapeUtil::MakeShape(F32, {8, 8, 4})});
  const Shape expected_shape = ShapeUtil::MakeShape(F32, {8, 8, 4});

  auto inferred_ag_done_shape =
      ShapeInference::InferAllGatherDoneShape(input_shape);
  EXPECT_TRUE(inferred_ag_done_shape.ok());
  EXPECT_TRUE(
      ShapeUtil::Equal(inferred_ag_done_shape.ValueOrDie(), expected_shape));
}
425
// Basic convolution with non-trivial dimension numbers (kernel laid out as
// x1, batch, feature, x0), stride, and padding infers the expected output.
TEST_F(ShapeInferenceTest, Convolve) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(3);
  dim0->set_stride(2);
  dim0->set_padding_low(1);
  dim0->set_padding_high(1);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(1);
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(0);
  dim1->set_padding_high(0);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
                               inferred_shape));
}
470
// Convolution with window (kernel) dilation: the effective kernel extent
// grows, so spatial outputs shrink accordingly.
TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  auto dim0 = window.add_dimensions();
  dim0->set_size(3);
  dim0->set_stride(3);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim0->set_window_dilation(6);
  dim0->set_base_dilation(1);

  auto dim1 = window.add_dimensions();
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim1->set_window_dilation(2);
  dim1->set_base_dilation(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
                               inferred_shape));
}
516
// Convolution with base (input) dilation: the effective input extent grows,
// so spatial outputs grow accordingly.
TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  auto dim0 = window.add_dimensions();
  dim0->set_size(4);
  dim0->set_stride(3);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(6);

  auto dim1 = window.add_dimensions();
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(2);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
                               inferred_shape));
}
562
// Dimension numbers that assign the same kernel dimension twice (input
// feature and a spatial dimension both at 0) are rejected.
TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
  // Dimension order for this test: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2});

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(3);
  dnums.set_output_batch_dimension(3);
  dnums.set_input_feature_dimension(2);
  dnums.set_output_feature_dimension(2);
  dnums.add_input_spatial_dimensions(0);
  dnums.add_output_spatial_dimensions(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.add_output_spatial_dimensions(1);
  dnums.set_kernel_input_feature_dimension(0);  // duplicated with kernel_x0
  dnums.set_kernel_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);

  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(2);
  dim0->set_stride(1);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim1->set_size(3);
  dim1->set_stride(2);
  dim1->set_padding_low(1);
  dim1->set_padding_high(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("each dimension exactly once"));
}
600
// A batch group count (6) that does not divide the kernel output feature
// count (10) is rejected.
TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) {
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(2);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(3);
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {60, 38, 17, 13});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {38, 10, 4, 4});
  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(4);
  dim1->set_size(4);
  dim0->set_padding_low(0);
  dim0->set_padding_high(2);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim0->set_stride(1);
  dim1->set_stride(1);
  dim0->set_window_dilation(3);
  dim1->set_window_dilation(2);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("to be a multiple of batch group count"));
}
637
638 struct ConvolveArgs {
639 Shape lhs_shape;
640 Shape rhs_shape;
641 ConvolutionDimensionNumbers dnums;
642 Window window;
643 };
644
MakeConvolveArgs(PrimitiveType lhs_type,PrimitiveType rhs_type)645 ConvolveArgs MakeConvolveArgs(PrimitiveType lhs_type, PrimitiveType rhs_type) {
646 ConvolveArgs args;
647 ConvolutionDimensionNumbers& dnums = args.dnums;
648
649 // Dimension order: batch, feature, x0, x1
650 args.lhs_shape = ShapeUtil::MakeShape(lhs_type, {10, 11, 3, 4});
651 dnums.set_input_batch_dimension(0);
652 dnums.set_output_batch_dimension(0);
653 dnums.set_input_feature_dimension(1);
654 dnums.set_output_feature_dimension(1);
655 dnums.add_input_spatial_dimensions(2);
656 dnums.add_output_spatial_dimensions(2);
657 dnums.add_input_spatial_dimensions(3);
658 dnums.add_output_spatial_dimensions(3);
659
660 // Dimension order: x1, batch, feature, x0
661 args.rhs_shape = ShapeUtil::MakeShape(rhs_type, {2, 12, 11, 3});
662 dnums.set_kernel_input_feature_dimension(2);
663 dnums.set_kernel_output_feature_dimension(1);
664 dnums.add_kernel_spatial_dimensions(3);
665 dnums.add_kernel_spatial_dimensions(0);
666
667 auto dim0 = args.window.add_dimensions();
668 auto dim1 = args.window.add_dimensions();
669 dim0->set_size(3);
670 dim0->set_stride(2);
671 dim0->set_padding_low(1);
672 dim0->set_padding_high(1);
673 dim0->set_window_dilation(1);
674 dim0->set_base_dilation(1);
675 dim1->set_size(2);
676 dim1->set_stride(1);
677 dim1->set_padding_low(0);
678 dim1->set_padding_high(0);
679 dim1->set_window_dilation(1);
680 dim1->set_base_dilation(1);
681 return args;
682 }
683
// Mixed BF16 x F16 convolution infers a BF16 result.
TEST_F(ShapeInferenceTest, ConvolveWithBF16_F16) {
  ConvolveArgs args = MakeConvolveArgs(BF16, F16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}
695
// Mixed F16 x BF16 convolution also infers BF16 — the result type is
// symmetric in the operand order.
TEST_F(ShapeInferenceTest, ConvolveWithF16_BF16) {
  ConvolveArgs args = MakeConvolveArgs(F16, BF16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}
707
// Mixed S32 x U32 convolution infers an S32 result.
TEST_F(ShapeInferenceTest, ConvolveWithS32_U32) {
  ConvolveArgs args = MakeConvolveArgs(S32, U32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}
719
// Mixed U32 x S32 convolution also infers S32 — symmetric in operand order.
TEST_F(ShapeInferenceTest, ConvolveWithU32_S32) {
  ConvolveArgs args = MakeConvolveArgs(U32, S32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}
731
// An explicitly requested preferred_element_type (S16) overrides the
// element type that S8 x S16 would otherwise infer.
TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S16));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S16, {10, 12, 2, 3}),
                               inferred_shape));
}
743
// Requesting S32 — the same type S8 x S16 convolution infers anyway — is a
// no-op and succeeds.
TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementTypeSameAsInferredType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S32));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}
755
// Narrowing is allowed for floating-point convolutions: F32 inputs may
// request a narrower BF16 result.
TEST_F(ShapeInferenceTest,
       FloatingPointConvolveWithNarrowerPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(F32, F32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/BF16));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}
768
// A floating-point convolution may request an integral result type; the
// output takes the requested S32 type.
TEST_F(ShapeInferenceTest,
       FloatingPointConvolveWithIntegralPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(BF16, BF16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S32));
  const Shape expected = ShapeUtil::MakeShape(S32, {10, 12, 2, 3});
  ASSERT_TRUE(ShapeUtil::Equal(expected, inferred));
}
781
// An integral convolution may request a floating-point result type; the
// output takes the requested F32 type.
TEST_F(ShapeInferenceTest,
       IntegralConvolveWithFloatingPointPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/F32));
  const Shape expected = ShapeUtil::MakeShape(F32, {10, 12, 2, 3});
  ASSERT_TRUE(ShapeUtil::Equal(expected, inferred));
}
794
// Signed inputs may request an unsigned result type of the same width.
TEST_F(ShapeInferenceTest,
       ConvolveWithPreferredElementTypeWithDifferentSignedness) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/U32));
  const Shape expected = ShapeUtil::MakeShape(U32, {10, 12, 2, 3});
  ASSERT_TRUE(ShapeUtil::Equal(expected, inferred));
}
807
// For integral convolutions a preferred type narrower than the naturally
// inferred one (S8 < S32) is rejected.
TEST_F(ShapeInferenceTest, ConvolveWithNarrowerPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  auto status = ShapeInference::InferConvolveShape(
                    args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
                    /*batch_group_count=*/1, args.window, args.dnums,
                    /*preferred_element_type=*/S8)
                    .status();
  ASSERT_FALSE(status.ok());
  ASSERT_THAT(status.error_message(),
              HasSubstr("must not be narrower than the original type"));
}
820
821 namespace fft {
822
823 static const char* unsupported_rank = "only supports ranks 1-3";
824 static const char* invalid_rank = "requires input of at least same rank";
825 static const char* requires_complex_input = "requires complex input type";
826 static const char* requires_f32_input = "requires F32 or F64 input type";
827 static const char* dimensions_match = "innermost dimensions match fft_length";
828 static const char* innermost_dimension_matches =
829 "innermost dimension matches fft_length/2+1";
830
Pass(const Shape & shape,FftType type,absl::Span<const int64> length,const Shape & expected_shape)831 static void Pass(const Shape& shape, FftType type,
832 absl::Span<const int64> length, const Shape& expected_shape) {
833 auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
834 ASSERT_IS_OK(inferred_status.status());
835 Shape inferred_shape = inferred_status.ValueOrDie();
836 ASSERT_TRUE(ShapeUtil::Equal(inferred_shape, expected_shape));
837 }
838
Fail(const Shape & shape,FftType type,absl::Span<const int64> length,absl::string_view message)839 static void Fail(const Shape& shape, FftType type,
840 absl::Span<const int64> length, absl::string_view message) {
841 auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
842 ASSERT_FALSE(inferred_status.ok());
843 ASSERT_THAT(inferred_status.status().error_message(),
844 HasSubstr(std::string(message)));
845 }
846
847 } // namespace fft
848
// FFT supports fft_length ranks 1-3, bounded by the input rank.
TEST_F(ShapeInferenceTest, InferFftShapeTestFftRanks) {
  const FftType kType = FftType::FFT;
  const Shape shape = ShapeUtil::MakeShape(C64, {16, 8});
  fft::Fail(shape, kType, {}, fft::unsupported_rank);
  fft::Pass(shape, kType, {8}, shape);
  fft::Pass(shape, kType, {16, 8}, shape);
  fft::Fail(shape, kType, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(shape, kType, {64, 32, 16, 8}, fft::unsupported_rank);
}
858
// FFT requires a complex input element type.
TEST_F(ShapeInferenceTest, InferFftShapeTestFftTypes) {
  const FftType kType = FftType::FFT;
  const Shape real_shape = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape complex_shape = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(real_shape, kType, {16, 8}, fft::requires_complex_input);
  fft::Pass(complex_shape, kType, {16, 8}, complex_shape);
}
866
// IFFT has the same rank constraints as FFT: fft_length ranks 1-3.
TEST_F(ShapeInferenceTest, InferFftShapeTestIfftRanks) {
  const FftType kType = FftType::IFFT;
  const Shape shape = ShapeUtil::MakeShape(C64, {16, 8});
  fft::Fail(shape, kType, {}, fft::unsupported_rank);
  fft::Pass(shape, kType, {8}, shape);
  fft::Pass(shape, kType, {16, 8}, shape);
  fft::Fail(shape, kType, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(shape, kType, {64, 32, 16, 8}, fft::unsupported_rank);
}
876
// IFFT likewise requires a complex input element type.
TEST_F(ShapeInferenceTest, InferFftShapeTestIfftTypes) {
  const FftType kType = FftType::IFFT;
  const Shape real_shape = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape complex_shape = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(real_shape, kType, {16, 8}, fft::requires_complex_input);
  fft::Pass(complex_shape, kType, {16, 8}, complex_shape);
}
884
// RFFT: real input of rank 1-3 maps to a complex output whose innermost
// dimension is fft_length/2+1 (8 -> 5 here).
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftRanks) {
  const FftType kType = FftType::RFFT;
  const Shape input = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape output = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Fail(input, kType, {}, fft::unsupported_rank);
  fft::Pass(input, kType, {8}, output);
  fft::Pass(input, kType, {16, 8}, output);
  fft::Fail(input, kType, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(input, kType, {64, 32, 16, 8}, fft::unsupported_rank);
}
895
// RFFT: fft_length must match the innermost input dimensions; zero-sized
// dimensions are allowed, and even/odd lengths both truncate to len/2+1.
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftDimensions) {
  const FftType kType = FftType::RFFT;
  const Shape input = ShapeUtil::MakeShape(F32, {16, 8});
  fft::Fail(input, kType, {4}, fft::dimensions_match);
  fft::Fail(input, kType, {16, 4}, fft::dimensions_match);
  fft::Fail(input, kType, {8, 8}, fft::dimensions_match);
  fft::Fail(input, kType, {8, 16}, fft::dimensions_match);

  const Shape zero_in = ShapeUtil::MakeShape(F32, {16, 0});
  const Shape zero_out = ShapeUtil::MakeShape(C64, {16, 0});
  fft::Pass(zero_in, kType, {0}, zero_out);
  fft::Pass(zero_in, kType, {16, 0}, zero_out);

  // 8/2+1 == 5 and 9/2+1 == 5: even and odd lengths share an output shape.
  const Shape even_in = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape odd_in = ShapeUtil::MakeShape(F32, {16, 9});
  const Shape output = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Pass(even_in, kType, {16, 8}, output);
  fft::Pass(odd_in, kType, {16, 9}, output);
}
915
// RFFT requires a real (F32/F64) input; complex inputs are rejected.
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftTypes) {
  const FftType kType = FftType::RFFT;
  const Shape c64_shape = ShapeUtil::MakeShape(C64, {16, 8});
  const Shape c128_shape = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(c64_shape, kType, {16, 8}, fft::requires_f32_input);
  fft::Fail(c128_shape, kType, {16, 8}, fft::requires_f32_input);
}
923
// IRFFT: complex input with innermost dim fft_length/2+1 maps back to a
// real output of the full fft_length (5 -> 8 here).
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftRanks) {
  const FftType kType = FftType::IRFFT;
  const Shape input = ShapeUtil::MakeShape(C64, {16, 5});
  const Shape output = ShapeUtil::MakeShape(F32, {16, 8});
  fft::Fail(input, kType, {}, fft::unsupported_rank);
  fft::Pass(input, kType, {8}, output);
  fft::Pass(input, kType, {16, 8}, output);
  fft::Fail(input, kType, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(input, kType, {64, 32, 16, 8}, fft::unsupported_rank);
}
934
// IRFFT: the innermost input dimension must equal fft_length/2+1 and the
// outer dimensions must match fft_length exactly.
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftDimensions) {
  const FftType kType = FftType::IRFFT;
  const Shape input = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Fail(input, kType, {5}, fft::innermost_dimension_matches);
  fft::Fail(input, kType, {16, 5}, fft::innermost_dimension_matches);
  fft::Fail(input, kType, {8, 8}, fft::dimensions_match);
  fft::Fail(input, kType, {8, 9}, fft::dimensions_match);

  const Shape zero_in = ShapeUtil::MakeShape(C64, {16, 0});
  const Shape zero_out = ShapeUtil::MakeShape(F32, {16, 0});
  fft::Pass(zero_in, kType, {0}, zero_out);
  fft::Pass(zero_in, kType, {16, 0}, zero_out);

  // Both even (8) and odd (9) fft_lengths are recoverable from innermost 5.
  const Shape even_out = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape odd_out = ShapeUtil::MakeShape(F32, {16, 9});
  fft::Pass(input, kType, {16, 8}, even_out);
  fft::Pass(input, kType, {16, 9}, odd_out);
}
953
// IRFFT requires a complex input; C128 input produces an F64 output.
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftTypes) {
  const FftType kType = FftType::IRFFT;
  const Shape real_input = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape c128_input = ShapeUtil::MakeShape(C128, {16, 5});
  const Shape f64_output = ShapeUtil::MakeShape(F64, {16, 8});
  fft::Fail(real_input, kType, {16, 8}, fft::requires_complex_input);
  fft::Pass(c128_input, kType, {16, 8}, f64_output);
}
962
// Map's result element type follows the mapped computation (F32 -> S32),
// while the dimensions follow the operand.
TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
  const Shape operand = ShapeUtil::MakeShape(F32, {20});
  const ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_);
  auto statusor = ShapeInference::InferMapShape({&operand}, to_apply, {0});
  EXPECT_IS_OK(statusor.status());
  const Shape expected = ShapeUtil::MakeShape(S32, {20});
  EXPECT_TRUE(ShapeUtil::Equal(expected, statusor.ValueOrDie()));
}
971
// Exercises InferMapShape: valid rank-1/rank-2 maps, then each distinct
// validation error (argument count, shape mismatch, arity, scalar-ness,
// and element-type agreement).
TEST_F(ShapeInferenceTest, Map) {
  // Two rank-1 F32 operands map to the same rank-1 F32 shape.
  auto r1f32_result = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  EXPECT_IS_OK(r1f32_result.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, r1f32_result.ValueOrDie()));

  // It's OK to provide a single argument, as long as the applied arity
  // matches (this degenerates to a Map).
  auto r1f32_single = ShapeInference::InferMapShape(
      {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0});
  EXPECT_IS_OK(r1f32_single.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, r1f32_single.ValueOrDie()));

  // Three rank-2 S32 operands map to the same rank-2 S32 shape.
  auto r2s32_result = ShapeInference::InferMapShape(
      {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_},
      ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1});
  EXPECT_IS_OK(r2s32_result.status());
  EXPECT_TRUE(ShapeUtil::Equal(s32matrix_64_64_, r2s32_result.ValueOrDie()));

  // Error: a map needs at least one operand.
  auto no_args_error = ShapeInference::InferMapShape(
      {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {});
  ASSERT_FALSE(no_args_error.ok());
  ASSERT_THAT(no_args_error.status().error_message(),
              HasSubstr("expects at least one argument"));

  // Error: all operands must share one shape.
  auto args_diff_shapes_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_64_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(args_diff_shapes_error.ok());
  ASSERT_THAT(args_diff_shapes_error.status().error_message(),
              HasSubstr("requires all operands to have the same shape"));

  // Error: computation arity must equal the operand count.
  auto arity_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_),
      {0});
  ASSERT_FALSE(arity_error.ok());
  ASSERT_THAT(arity_error.status().error_message(),
              HasSubstr("function arity must match"));

  // Error: the mapped computation must return a scalar.
  auto output_shape_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0});
  ASSERT_FALSE(output_shape_error.ok());
  ASSERT_THAT(output_shape_error.status().error_message(),
              HasSubstr("result has to be a scalar"));

  // Error: the mapped computation's parameters must be scalars.
  auto param_shape_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0});
  ASSERT_FALSE(param_shape_error.ok());
  ASSERT_THAT(param_shape_error.status().error_message(),
              HasSubstr("parameter has to be a scalar"));

  // Error: parameter element types must match the operands.
  auto param_element_type_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0});
  ASSERT_FALSE(param_element_type_error.ok());
  ASSERT_THAT(param_element_type_error.status().error_message(),
              HasSubstr("parameter type has to match argument"));

  // A simple valid single-operand map preserves the operand's shape.
  Shape arg = ShapeUtil::MakeShape(F32, {20});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
  auto ok_result = ShapeInference::InferMapShape({&arg}, to_apply, {0});
  EXPECT_IS_OK(ok_result.status());
  EXPECT_TRUE(ShapeUtil::Equal(arg, ok_result.ValueOrDie()));

  // Error: one operand but a two-parameter computation.
  auto error_arity = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(error_arity.ok());
  ASSERT_THAT(error_arity.status().error_message(),
              HasSubstr("arity must match number of arguments"));

  // Error: non-scalar parameter.
  auto error_vector_param = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0});
  ASSERT_FALSE(error_vector_param.ok());
  ASSERT_THAT(error_vector_param.status().error_message(),
              HasSubstr("has to be a scalar"));

  // Error: non-scalar result.
  auto error_vector_result = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0});
  ASSERT_FALSE(error_vector_result.ok());
  ASSERT_THAT(error_vector_result.status().error_message(),
              HasSubstr("has to be a scalar"));

  // Error: parameter element type disagrees with the operand.
  auto error_element_type = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0});
  ASSERT_FALSE(error_element_type.ok());
  ASSERT_THAT(error_element_type.status().error_message(),
              HasSubstr("parameter type has to match argument"));
}
1065
// Reducing the single dimension of a vector yields a scalar.
TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) {
  const Shape input = ShapeUtil::MakeShape(F32, {128});
  ExpectInferredReduceShape(f32_, input,
                            /*dimensions_to_reduce=*/{0});
}
1070
// Reducing dim 0 of a [2,3,4] cube leaves a [3,4] matrix.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstDimension) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3, 4}), cube,
                            /*dimensions_to_reduce=*/{0});
}
1076
// Reducing dim 1 of a [2,3,4] cube leaves a [2,4] matrix.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongMiddleDimension) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2, 4}), cube,
                            /*dimensions_to_reduce=*/{1});
}
1082
// Reducing dims {0,1} of a [2,3,4] cube leaves a [4] vector.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstTwoDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {4}), cube,
                            /*dimensions_to_reduce=*/{0, 1});
}
1088
// Reducing dims {1,2} of a [2,3,4] cube leaves a [2] vector.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongLastTwoDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2}), cube,
                            /*dimensions_to_reduce=*/{1, 2});
}
1094
// Reducing dims {0,2} of a [2,3,4] cube leaves a [3] vector, regardless of
// the order in which the dimensions are listed.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstAndLastDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {3});
  ExpectInferredReduceShape(expected, cube,
                            /*dimensions_to_reduce=*/{0, 2});

  // Check that the order of dimensions_to_reduce doesn't matter.
  ExpectInferredReduceShape(expected, cube,
                            /*dimensions_to_reduce=*/{2, 0});
}
1105
// Reducing every dimension of a cube yields a scalar.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(f32_, cube,
                            /*dimensions_to_reduce=*/{0, 1, 2});
}
1110
// A variadic reduce over (f32, s32) arrays with a tuple-returning reducer
// produces a tuple of scalars when all dimensions are reduced.
TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) {
  Shape f32_arg = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg = ShapeUtil::MakeShape(S32, {5, 3});
  const Shape result_tuple = ShapeUtil::MakeTupleShape({f32_, s32_});
  ProgramShape to_apply =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, result_tuple);
  auto statusor = ShapeInference::InferReduceShape(
      {&f32_arg, &s32_arg, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_IS_OK(statusor.status());
  EXPECT_TRUE(ShapeUtil::Equal(result_tuple, statusor.ValueOrDie()));
}
1122
// Variadic ReduceWindow over (f32, s32) arrays produces a tuple of the
// windowed result shapes.
TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> args = {&f32_arg_shape, &s32_arg_shape};
  std::vector<const Shape*> inits = {&f32_, &s32_};
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  std::vector<int64> window_dimensions = {1, 2, 4};
  std::vector<int64> window_strides = {1, 1, 1};
  std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(AsInt64Slice(f32_arg_shape.dimensions()), window_dimensions,
                  window_strides, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(
      Window window,
      ShapeInference::InferWindowFromDimensions(
          window_dimensions, window_strides, padding_values, {}, {}));
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
  // Verify success BEFORE calling ValueOrDie(): calling ValueOrDie() on a
  // failed inference CHECK-fails and would mask the actual error message.
  ASSERT_IS_OK(inferred_status.status());
  VLOG(2) << inferred_status.ValueOrDie().ToString() << "\n";
  EXPECT_TRUE(ShapeUtil::Equal(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {5, 2, 0}),
                                 ShapeUtil::MakeShape(S32, {5, 2, 0})}),
      inferred_status.ValueOrDie()));
}
1148
// A 6-parameter reducer cannot serve a 2-array reduce (which needs 4).
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) {
  Shape f32_arg = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_},
                                  ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto statusor = ShapeInference::InferReduceShape(
      {&f32_arg, &s32_arg, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("must take 4 parameters, but takes 6 parameter(s)"));
}
1161
// A reducer whose parameter type disagrees with its result tuple element
// (s32 parameter vs f32 result) is rejected.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) {
  Shape f32_arg = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto statusor = ShapeInference::InferReduceShape(
      {&f32_arg, &s32_arg, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "parameter shape differs from the result shape: s32[] vs f32[]"));
}
1175
// A reduce with zero operands is rejected; at least one array and one init
// value are required.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto statusor = ShapeInference::InferReduceShape({}, {0, 1}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("must have at least 2 arguments, has 0"));
}
1184
// ReduceWindow rejects a reducer whose parameter types (all f32) disagree
// with the s32 operand.
TEST_F(ReduceShapeInferenceTest, ErrorBadReduceWindowInput) {
  Shape f32_arg = ShapeUtil::MakeShape(F32, {5, 3, 1});
  Shape s32_arg = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> args = {&f32_arg, &s32_arg};
  std::vector<const Shape*> inits = {&f32_, &s32_};
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, f32_, f32_, f32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  std::vector<int64> window_dimensions = {1, 2, 4};
  std::vector<int64> window_strides = {1, 1, 1};
  std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(AsInt64Slice(f32_arg.dimensions()), window_dimensions,
                  window_strides, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(
      Window window,
      ShapeInference::InferWindowFromDimensions(
          window_dimensions, window_strides, padding_values, {}, {}));
  auto statusor = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
  EXPECT_FALSE(statusor.status().ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("f32[] vs s32[]"));
}
1207
// A multi-output reduce needs a tuple-returning reducer; a scalar result
// is rejected.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) {
  Shape f32_arg = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_);
  auto statusor = ShapeInference::InferReduceShape(
      {&f32_arg, &s32_arg, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but produces a scalar"));
}
1220
// The reducer's result tuple must have exactly one element per reduced
// array; a 3-element tuple for 2 arrays is rejected.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) {
  Shape f32_arg = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_}));
  auto statusor = ShapeInference::InferReduceShape(
      {&f32_arg, &s32_arg, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but has 3 elements"));
}
1233
// When the reducer's accumulator type (s32) disagrees with the init value
// (f32), the mismatch is reported against the first offending index.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) {
  Shape f32_arg = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_}));
  auto statusor = ShapeInference::InferReduceShape(
      {&f32_arg, &s32_arg, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("accumulator shape at index 0 differs from the "
                        "init_value shape: s32[] vs f32[]"));
}
1246
// Reduction dimensions beyond the operand's rank are rejected.
TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
  Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  auto statusor = ShapeInference::InferReduceShape(
      {&operand, &f32_},
      /*dimensions_to_reduce=*/{3, 4}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("out-of-bounds dimension"));
}
1257
// A single-array reduce needs a binary reducer; a 3-parameter computation
// is rejected.
TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);
  Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  auto statusor =
      ShapeInference::InferReduceShape({&operand, &f32_},
                                       /*dimensions_to_reduce=*/{0}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("take 2 parameters"));
}
1268
// The reducer's result type (s32) must match its f32 parameters/operand.
TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);
  Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  auto statusor =
      ShapeInference::InferReduceShape({&operand, &f32_},
                                       /*dimensions_to_reduce=*/{0}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("0-th parameter shape differs"));
}
1279
// Listing the same reduction dimension twice is rejected.
TEST_F(ReduceShapeInferenceTest, ReduceWithRepeatedReduceDimension) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
  Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  auto statusor = ShapeInference::InferReduceShape(
      {&operand, &f32_},
      /*dimensions_to_reduce=*/{0, 0}, to_apply);
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Duplicate reduction dimension: 0"));
}
1290
// A unit-stride slice of rows [32,64) of a [128,64] matrix is [32,64].
TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
  const Shape matrix = ShapeUtil::MakeShape(F32, {128, 64});
  auto statusor =
      ShapeInference::InferSliceShape(matrix, {32, 0}, {64, 64}, {1, 1});
  ASSERT_IS_OK(statusor.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}),
                               statusor.ValueOrDie()));
}
1299
// Slicing a dynamic dimension down to a single element makes it static;
// the untouched dynamic dimension stays dynamic.
TEST_F(ShapeInferenceTest, InferSliceWithDynamicDimensions) {
  const Shape matrix = ShapeUtil::MakeShape(F32, {128, 64}, {true, true});
  auto statusor =
      ShapeInference::InferSliceShape(matrix, {32, 0}, {33, 64}, {1, 1});
  ASSERT_IS_OK(statusor.status());
  ASSERT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 64}, {false, true}),
                       statusor.ValueOrDie()));
}
1309
// Strides divide the sliced extents: (64-32)/2 = 16 and 64/4 = 16.
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
  const Shape matrix = ShapeUtil::MakeShape(F32, {128, 64});
  auto statusor =
      ShapeInference::InferSliceShape(matrix, {32, 0}, {64, 64}, {2, 4});
  ASSERT_IS_OK(statusor.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}),
                               statusor.ValueOrDie()));
}
1318
// When the stride does not evenly divide the sliced extent the result
// rounds up: ceil(5/2) = 3 and ceil(13/4) = 4.
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
  const Shape matrix = ShapeUtil::MakeShape(F32, {128, 64});
  auto statusor =
      ShapeInference::InferSliceShape(matrix, {15, 0}, {20, 13}, {2, 4});
  ASSERT_IS_OK(statusor.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}),
                               statusor.ValueOrDie()));
}
1327
// A zero stride is invalid and yields INVALID_ARGUMENT.
TEST_F(ShapeInferenceTest, InferInvalidStride) {
  const Shape matrix = ShapeUtil::MakeShape(F32, {128, 64});
  auto statusor =
      ShapeInference::InferSliceShape(matrix, {127, 0}, {129, 2}, {0, 1});
  ASSERT_FALSE(statusor.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, statusor.status().code());
}
1336
// Limits past the dimension bound (129 > 128) yield INVALID_ARGUMENT.
TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
  const Shape matrix = ShapeUtil::MakeShape(F32, {128, 64});
  auto statusor =
      ShapeInference::InferSliceShape(matrix, {127, 0}, {129, 2}, {1, 1});
  ASSERT_FALSE(statusor.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, statusor.status().code());
}
1345
// Slicing [2,4) of a length-17 vector yields a length-2 vector.
TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
  const Shape vector = ShapeUtil::MakeShape(F32, {17});
  auto statusor = ShapeInference::InferSliceShape(vector, {2}, {4}, {1});
  ASSERT_TRUE(statusor.ok());
  ASSERT_TRUE(ShapeUtil::Equal(statusor.ValueOrDie(),
                               ShapeUtil::MakeShape(F32, {2})));
}
1354
// GetTupleElement returns the element shape at each valid index.
TEST_F(ShapeInferenceTest, InferConstIndexShape) {
  const Shape tuple = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto statusor0 = ShapeInference::InferGetTupleElementShape(tuple, 0);
  auto statusor1 = ShapeInference::InferGetTupleElementShape(tuple, 1);
  ASSERT_IS_OK(statusor0.status());
  ASSERT_IS_OK(statusor1.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, statusor0.ValueOrDie()));
  ASSERT_TRUE(ShapeUtil::Equal(s32_, statusor1.ValueOrDie()));
}
1366
TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) {
  // Indexing a 2-element tuple at -1 or 2 is out of bounds; both directions
  // must fail with the same diagnostic.
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto negative_index_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, -1);
  auto too_large_index_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, 2);
  ASSERT_FALSE(negative_index_status.ok());
  ASSERT_FALSE(too_large_index_status.ok());
  EXPECT_THAT(negative_index_status.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
  EXPECT_THAT(too_large_index_status.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
}
1380
TEST_F(ShapeInferenceTest, InferPowShape) {
  // Power of a f32 vector by a f32 scalar broadcasts to the vector's shape.
  const Shape ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto pow_status = ShapeInference::InferBinaryOpShape(HloOpcode::kPower,
                                                       ten_floats, f32_, {});
  ASSERT_IS_OK(pow_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(ten_floats, pow_status.ValueOrDie()));
}
1388
TEST_F(ShapeInferenceTest, InferCompareShape) {
  // Comparing a f32 vector against a f32 scalar yields a PRED vector of the
  // same length.
  const Shape ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto compare_status = ShapeInference::InferBinaryOpShape(
      HloOpcode::kCompare, ten_floats, f32_, {});
  ASSERT_IS_OK(compare_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
                               compare_status.ValueOrDie()));
}
1397
TEST_F(ShapeInferenceTest, InferReshapeDegenerateCombine) {
  // [1, <=1]
  //   | reshape
  // [<=1]
  //
  // Either input dimension could map to the single output dimension; with
  // inferred_dimension unset (-1) the dynamic bit of the input wins.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {1, 1}, {false, true});
  auto reshape_status = ShapeInference::InferReshapeShape(
      operand_shape, {1, 0}, {1}, /*inferred_dimension=*/-1);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {1}, {true}),
            reshape_status.ValueOrDie());
}
1409
TEST_F(ShapeInferenceTest, InferReshapeSplit) {
  // [<=10]
  //   | reshape
  // [1, 10]
  //
  // Either output dimension could carry the dynamic bit;
  // inferred_dimension=0 selects output dimension 0 as the dynamic one.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {10}, {true});
  auto reshape_status = ShapeInference::InferReshapeShape(
      operand_shape, {0}, {1, 10}, /*inferred_dimension=*/0);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {1, 10}, {true, false}),
            reshape_status.ValueOrDie());
}
1422
TEST_F(ShapeInferenceTest, InferReshapeCombine) {
  // [6, <=10]
  //   | reshape
  // [<=60]
  //
  // Combining a static and a dynamic dimension makes the output dynamic.
  // inferred_dimension is deliberately not a valid output index (-11),
  // i.e. unset.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
  auto reshape_status = ShapeInference::InferReshapeShape(
      operand_shape, {1, 0}, {60}, /*inferred_dimension=*/-11);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {60}, {true}),
            reshape_status.ValueOrDie());
}
1432
TEST_F(ShapeInferenceTest, UnchangedDimension) {
  // [6, <=10]
  //   | reshape
  // [2, 3, <=10]
  //
  // The trailing dimension passes through unchanged and keeps its dynamic
  // bit; the split static dimensions stay static.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
  auto reshape_status = ShapeInference::InferReshapeShape(
      operand_shape, {1, 0}, {2, 3, 10}, /*inferred_dimension=*/-11);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {2, 3, 10}, {false, false, true}),
            reshape_status.ValueOrDie());
}
1443
TEST_F(ShapeInferenceTest, InferDynamicBroadcast) {
  // Broadcasting a dynamic vector prepends a static dimension and preserves
  // the operand's dynamic bit on the trailing dimension.  (The shapes below
  // are F32, not s32 as a previous version of this comment claimed.)
  // CHECK:
  // %broadcast = f32[15,<=15]{1,0} broadcast(f32[<=15]{0}), dimensions={1}

  auto operand_shape = ShapeUtil::MakeShape(F32, {15}, {true});
  auto inferred_status =
      ShapeInference::InferBroadcastShape(operand_shape, {15});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {15, 15}, {false, true}), inferred);
}
1455
TEST_F(ShapeInferenceTest, BroadcastScalar) {
  // Exercises InferBroadcastShape over several element types: no-op
  // broadcasts plus scalar->1d, scalar->2d and 1d->2d expansions.
  for (auto element_type : {F32, U32, S8}) {
    const Shape scalar_shape = ShapeUtil::MakeShape(element_type, {});
    const Shape oned_shape = ShapeUtil::MakeShape(element_type, {3});
    const Shape twod_shape = ShapeUtil::MakeShape(element_type, {2, 3});

    {  // Broadcasting a scalar with no broadcast sizes is a no-op.
      auto result = ShapeInference::InferBroadcastShape(scalar_shape, {});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(scalar_shape, result.ValueOrDie()));
    }
    {  // scalar -> 1d broadcast.
      auto result = ShapeInference::InferBroadcastShape(scalar_shape, {3});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(oned_shape, result.ValueOrDie()));
    }
    {  // Broadcasting a 1d shape with no broadcast sizes is a no-op.
      auto result = ShapeInference::InferBroadcastShape(oned_shape, {});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(oned_shape, result.ValueOrDie()));
    }
    {  // scalar -> 2d broadcast.
      auto result = ShapeInference::InferBroadcastShape(scalar_shape, {2, 3});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(twod_shape, result.ValueOrDie()));
    }
    {  // 1d -> 2d broadcast: the new major dimension of size 2 is prepended.
      auto result = ShapeInference::InferBroadcastShape(oned_shape, {2});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(twod_shape, result.ValueOrDie()));
    }
  }
}
1488
1489 // scalar <dot> vector: ok
TEST_F(ShapeInferenceTest, ScalarDotVector) {
  // With no contracting dimensions, scalar <dot> vector takes the vector's
  // shape.
  DotDimensionNumbers dot_dnums;
  auto result = ShapeInference::InferDotOpShape(
      f32_, vector_32_, dot_dnums, /*preferred_element_type=*/absl::nullopt);
  EXPECT_TRUE(result.ok());
  EXPECT_EQ(result.ValueOrDie(), vector_32_);
}
1497
// 3D <dot> 2D: ok — lhs free dimensions are preserved, rhs free dimension
// appended.
TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
  // Contract lhs dim 1 with rhs dim 0: [32,32,32] <dot> [32,64] keeps the
  // remaining lhs dimensions in order and appends the free rhs dimension.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto result = ShapeInference::InferDotOpShape(
      ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  EXPECT_TRUE(result.ok());
  EXPECT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
                               ShapeUtil::MakeShape(F32, {32, 32, 64})));
}
1510
1511 // vector <dot> vector -> scalar
TEST_F(ShapeInferenceTest, VectorDotVector) {
  // Contracting the sole dimension of two equal-length vectors yields a
  // scalar; unequal lengths are an error.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto match_result =
      ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(match_result.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, match_result.ValueOrDie()));
  auto mismatch_result =
      ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(mismatch_result.ok());
}
1526
1527 // matrix <dot> vector -> vector
TEST_F(ShapeInferenceTest, MatrixDotVector) {
  // [32,64] <dot> [64] contracts the matrix columns against the vector,
  // yielding a length-32 vector; a length-32 rhs mismatches and fails.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto match_result =
      ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(match_result.status());
  ASSERT_TRUE(ShapeUtil::Equal(match_result.ValueOrDie(), vector_32_));
  auto mismatch_result =
      ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(mismatch_result.ok());
}
1542
1543 // vector <dot> matrix -> vector
TEST_F(ShapeInferenceTest, VectorDotMatrix) {
  // [32] <dot> [32,64] contracts the vector against the matrix rows,
  // yielding a length-64 vector; a length-64 lhs mismatches and fails.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto match_result =
      ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(match_result.status());
  ASSERT_TRUE(ShapeUtil::Equal(match_result.ValueOrDie(), vector_64_));
  auto mismatch_result =
      ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(mismatch_result.ok());
}
1558
1559 // matrix <dot> matrix -> matrix
TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
  // [32,64] <dot> [64,48] contracts the inner dimension, yielding [32,48].
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto inferred_status_match =
      ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status_match.status());
  // The failure message must print the shape the assertion actually compares
  // against (matrix_32_48_), not one of the operands.
  ASSERT_TRUE(
      ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
      << "inferred: "
      << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
      << " expected: " << ShapeUtil::HumanString(matrix_32_48_);
  // Contracting-dimension sizes disagree (64 vs 32), so inference fails.
  auto inferred_status_mismatch =
      ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status_mismatch.ok());
}
1578
1579 // BatchMatMul with two batch dimensions and one contracting dimension.
TEST_F(ShapeInferenceTest, DotGeneral) {
  // Batch dims {0,1} are carried through; lhs dim 3 contracts with rhs
  // dim 2: [5,2,11,3] <dot> [5,2,3,14] -> [5,2,11,14].
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(3);
  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_lhs_batch_dimensions(1);

  dot_dnums.add_rhs_contracting_dimensions(2);
  dot_dnums.add_rhs_batch_dimensions(0);
  dot_dnums.add_rhs_batch_dimensions(1);

  auto result =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), output_shape))
      << "inferred: " << ShapeUtil::HumanString(result.ValueOrDie())
      << " expected: " << ShapeUtil::HumanString(output_shape);
}
1604
1605 // BatchMatMul with two contracting dimensions fails.
TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) {
  // lhs declares two contracting dimensions while rhs declares only one,
  // which must be rejected.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_lhs_contracting_dimensions(3);
  dot_dnums.add_lhs_batch_dimensions(0);

  dot_dnums.add_rhs_contracting_dimensions(1);
  dot_dnums.add_rhs_batch_dimensions(0);

  auto result =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Must specify the same number of contracting "
                        "dimensions for lhs and rhs."));
}
1626
TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
  // Both operands declare two contracting dimensions of matching sizes
  // (3 and 2), so the dot contracts both: result is [2,11,14].
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 2, 14});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_lhs_contracting_dimensions(3);
  dot_dnums.add_lhs_batch_dimensions(0);

  dot_dnums.add_rhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(2);
  dot_dnums.add_rhs_batch_dimensions(0);

  auto result =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  EXPECT_TRUE(result.ok());
  EXPECT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), output_shape));
}
1647
TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) {
  // SetDimensionSize requires the size operand to be an S32 scalar; a
  // rank-1 S32 value is rejected.
  const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape val_shape = ShapeUtil::MakeShape(S32, {1});
  auto result = ShapeInference::InferSetDimensionSizeShape(
      arg_shape, val_shape, /*dimension=*/0);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("value has to be S32 scalar"));
}
1658
TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) {
  // SetDimensionSize requires the size operand to be an S32 scalar; a U32
  // scalar has the wrong element type and is rejected.
  const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape val_shape = ShapeUtil::MakeShape(U32, {});
  auto result = ShapeInference::InferSetDimensionSizeShape(
      arg_shape, val_shape, /*dimension=*/0);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("value has to be S32 scalar"));
}
1669
1670 // BatchMatMul with different batch dimension sizes fails.
TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
  // Batch dimension sizes disagree (2 vs 3), which must be rejected.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_lhs_batch_dimensions(0);

  dot_dnums.add_rhs_contracting_dimensions(1);
  dot_dnums.add_rhs_batch_dimensions(0);

  auto result =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Batch dimension sizes must match"));
}
1689
1690 // BatchMatMul with different batch dimension numbers passes
TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersPasses) {
  // The batch dimension may sit at different indices in lhs (0) and rhs (1);
  // only the sizes must agree.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_lhs_batch_dimensions(0);

  dot_dnums.add_rhs_contracting_dimensions(0);
  dot_dnums.add_rhs_batch_dimensions(1);

  auto result =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_TRUE(result.ok());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
                               ShapeUtil::MakeShape(F32, {2, 11, 14})));
}
1709
1710 // BatchMatMul with out-of-range dimension numbers fails.
TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) {
  // lhs contracting dimension 3 exceeds the rank-3 operand's valid range
  // [0, 3) and must be rejected.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(3);
  dot_dnums.add_lhs_batch_dimensions(0);

  dot_dnums.add_rhs_contracting_dimensions(0);
  dot_dnums.add_rhs_batch_dimensions(1);

  auto result =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("A dimension number is out of range"));
}
1729
1730 // BatchMatMul with non-unique dimension numbers fails.
TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) {
  // lhs dimension 0 is listed both as contracting and as batch, which must
  // be rejected as non-unique.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_lhs_batch_dimensions(0);

  dot_dnums.add_rhs_contracting_dimensions(0);
  dot_dnums.add_rhs_batch_dimensions(1);

  auto result =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("A dimension number is not unique"));
}
1749
TEST_F(ShapeInferenceTest, DotWithIntegralPreferredElementType) {
  // S8 <dot> S16 with a wider integral preferred type widens the result to
  // S32.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(const Shape result_shape,
                          ShapeInference::InferDotOpShape(
                              ShapeUtil::MakeShape(S8, {32, 32}),
                              ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
                              /*preferred_element_type=*/S32));
  EXPECT_TRUE(
      ShapeUtil::Equal(result_shape, ShapeUtil::MakeShape(S32, {32, 32})));
}
1762
TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeSameAsInferredType) {
  // BF16 <dot> F32 already infers F32, so preferring F32 is a no-op.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(const Shape result_shape,
                          ShapeInference::InferDotOpShape(
                              ShapeUtil::MakeShape(BF16, {32, 32}),
                              ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums,
                              /*preferred_element_type=*/F32));
  EXPECT_TRUE(
      ShapeUtil::Equal(result_shape, ShapeUtil::MakeShape(F32, {32, 32})));
}
1775
TEST_F(ShapeInferenceTest, FloatingPointDotWithNarrowerPreferredElementType) {
  // BF16 <dot> F32 would infer F32, but a narrower floating-point preferred
  // type (BF16) is honored for the result.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(const Shape result_shape,
                          ShapeInference::InferDotOpShape(
                              ShapeUtil::MakeShape(BF16, {32, 32}),
                              ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums,
                              /*preferred_element_type=*/BF16));
  EXPECT_TRUE(
      ShapeUtil::Equal(result_shape, ShapeUtil::MakeShape(BF16, {32, 32})));
}
1788
TEST_F(ShapeInferenceTest, FloatingPointDotWithIntegralPreferredElementType) {
  // A floating-point dot (BF16 operands) may request an integral result
  // type (S32).
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(const Shape result_shape,
                          ShapeInference::InferDotOpShape(
                              ShapeUtil::MakeShape(BF16, {32, 32}),
                              ShapeUtil::MakeShape(BF16, {32, 32}), dot_dnums,
                              /*preferred_element_type=*/S32));
  EXPECT_TRUE(
      ShapeUtil::Equal(result_shape, ShapeUtil::MakeShape(S32, {32, 32})));
}
1801
TEST_F(ShapeInferenceTest, IntegralDotWithFloatingPointPreferredElementType) {
  // An integral dot (S8 x S16 operands) may request a floating-point result
  // type (F32).
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(const Shape result_shape,
                          ShapeInference::InferDotOpShape(
                              ShapeUtil::MakeShape(S8, {32, 32}),
                              ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
                              /*preferred_element_type=*/F32));
  EXPECT_TRUE(
      ShapeUtil::Equal(result_shape, ShapeUtil::MakeShape(F32, {32, 32})));
}
1814
TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeWithDifferentSignedness) {
  // The preferred element type may differ in signedness from the operands
  // (signed S8/S16 in, unsigned U32 out).
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(const Shape result_shape,
                          ShapeInference::InferDotOpShape(
                              ShapeUtil::MakeShape(S8, {32, 32}),
                              ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
                              /*preferred_element_type=*/U32));
  EXPECT_TRUE(
      ShapeUtil::Equal(result_shape, ShapeUtil::MakeShape(U32, {32, 32})));
}
1827
TEST_F(ShapeInferenceTest, DotWithNarrowerPreferredElementType) {
  // An integral preferred type narrower than the widest operand type (S8
  // vs S16) is invalid.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto result_status = ShapeInference::InferDotOpShape(
                           ShapeUtil::MakeShape(S8, {32, 32}),
                           ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
                           /*preferred_element_type=*/S8)
                           .status();
  ASSERT_FALSE(result_status.ok());
  ASSERT_THAT(result_status.error_message(),
              HasSubstr("must not be narrower than the original type"));
}
1841
TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
  // Broadcasting a vector into a binary add with a matrix: the single
  // broadcast dimension must name the matrix dimension whose size matches
  // the vector's length.
  const Shape mat = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
  const Shape vec16 = ShapeUtil::MakeShape(F32, {16});

  // vec8 lines up with matrix dimension 1 (size 8)...
  auto match_result =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1});
  ASSERT_IS_OK(match_result.status());
  ASSERT_TRUE(ShapeUtil::Equal(match_result.ValueOrDie(), mat));

  // ...but not with dimension 0 (size 16).
  auto mismatch_result =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0});
  ASSERT_FALSE(mismatch_result.ok());

  // vec16 lines up with matrix dimension 0 (size 16)...
  match_result =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0});
  ASSERT_IS_OK(match_result.status());
  ASSERT_TRUE(ShapeUtil::Equal(match_result.ValueOrDie(), mat));

  // ...but not with dimension 1 (size 8).
  mismatch_result =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1});
  ASSERT_FALSE(mismatch_result.ok());
}
1867
TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) {
  // Broadcasting a matrix into a binary add with a rank-3 cube: each pair of
  // broadcast dimensions must name the cube dimensions matching the matrix.
  const Shape cube = ShapeUtil::MakeShape(F32, {16, 8, 4});
  const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4});
  const Shape matrix16_4 = ShapeUtil::MakeShape(F32, {16, 4});
  const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8});

  // [8,4] maps onto cube dimensions {1,2}.
  auto add_result = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube,
                                                       matrix8_4, {1, 2});
  ASSERT_IS_OK(add_result.status());
  ASSERT_TRUE(ShapeUtil::Equal(add_result.ValueOrDie(), cube));

  // [16,4] maps onto cube dimensions {0,2}.
  add_result = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube,
                                                  matrix16_4, {0, 2});
  ASSERT_IS_OK(add_result.status());
  ASSERT_TRUE(ShapeUtil::Equal(add_result.ValueOrDie(), cube));

  // [16,8] maps onto cube dimensions {0,1}.
  add_result = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube,
                                                  matrix16_8, {0, 1});
  ASSERT_IS_OK(add_result.status());
  ASSERT_TRUE(ShapeUtil::Equal(add_result.ValueOrDie(), cube));
}
1890
TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
  // Invalid broadcast_dimensions arguments for a binary add, each of which
  // must produce a distinct diagnostic.
  const Shape tensor = ShapeUtil::MakeShape(F32, {16, 8, 4});
  const Shape tensor8_8_8 = ShapeUtil::MakeShape(F32, {8, 8, 8});
  const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
  const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4});
  const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8});

  // Rank-expanding ("magical") broadcast without dimension numbers is
  // rejected.
  auto magic_broadcast =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {});
  ASSERT_FALSE(magic_broadcast.ok());
  ASSERT_THAT(magic_broadcast.status().error_message(),
              HasSubstr("Automatic"));

  // broadcast_dimension 3 exceeds the rank-3 tensor.
  auto dim_too_large =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3});
  ASSERT_FALSE(dim_too_large.ok());
  ASSERT_THAT(dim_too_large.status().error_message(),
              ContainsRegex("Broadcast dimension number .* too large"));

  // vec8 (size 8) does not match tensor dimension 0 (size 16).
  auto dim_size_mismatch =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0});
  ASSERT_FALSE(dim_size_mismatch.ok());
  ASSERT_THAT(dim_size_mismatch.status().error_message(),
              HasSubstr("Broadcast dimension 0 mismatch"));

  // Three broadcast dimensions for a rank-2 operand.
  auto too_many_dims = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2});
  ASSERT_FALSE(too_many_dims.ok());
  ASSERT_THAT(too_many_dims.status().error_message(),
              HasSubstr("broadcast_dimensions has to match"));

  // One of the listed dimensions exceeds the tensor's rank.
  auto dim_above_rank = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor, matrix8_4, {3, 0});
  ASSERT_FALSE(dim_above_rank.ok());
  ASSERT_THAT(dim_above_rank.status().error_message(),
              ContainsRegex("dimension number .* too large"));

  // The matrix's dimensions do not match the tensor's in this order.
  auto wrong_order_mismatch = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor, matrix8_4, {2, 1});
  ASSERT_FALSE(wrong_order_mismatch.ok());
  ASSERT_THAT(wrong_order_mismatch.status().error_message(),
              HasSubstr("dimension 0 mismatch"));

  // Broadcast dimensions must be strictly increasing, even when every size
  // happens to match (all dimensions are 8 here).
  auto repeated_dims = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0});
  ASSERT_FALSE(repeated_dims.ok());
  ASSERT_THAT(repeated_dims.status().error_message(),
              HasSubstr("dimensions order is wrong"));

  auto decreasing_dims = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0});
  ASSERT_FALSE(decreasing_dims.ok());
  ASSERT_THAT(decreasing_dims.status().error_message(),
              HasSubstr("dimensions order is wrong"));
}
1956
1957 // Tests for the while instruction with proper shapes.
TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) {
  // A well-formed while: cond maps the loop state to PRED, body maps it to
  // itself, and the inferred shape is the loop state's shape.
  const Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
  const ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
  const ProgramShape body =
      ShapeUtil::MakeProgramShape({result_shape}, result_shape);
  auto while_status = ShapeInference::InferWhileShape(cond, body, result_shape);
  ASSERT_IS_OK(while_status.status());
  const Shape inferred_shape = while_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(result_shape, inferred_shape));
}
1968
1969 // Tests for the while instruction with wrong shapes.
TEST_F(ShapeInferenceTest, WhileWithBadShapes) {
  // Each malformed cond/body signature must fail with a specific diagnostic.
  const Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
  const ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
  const ProgramShape body =
      ShapeUtil::MakeProgramShape({result_shape}, result_shape);

  // Condition with two parameters instead of one.
  auto two_param_cond = ShapeUtil::MakeProgramShape({s32_, result_shape}, pred_);
  auto cond_arity_status =
      ShapeInference::InferWhileShape(two_param_cond, body, result_shape);
  ASSERT_FALSE(cond_arity_status.ok());
  ASSERT_THAT(cond_arity_status.status().error_message(),
              HasSubstr("Condition must take 1 arguments"));

  // Body with two parameters instead of one.
  auto two_param_body =
      ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape);
  auto body_arity_status =
      ShapeInference::InferWhileShape(cond, two_param_body, result_shape);
  ASSERT_FALSE(body_arity_status.ok());
  ASSERT_THAT(body_arity_status.status().error_message(),
              HasSubstr("Body must take 1 arguments"));

  // Condition returning s32 instead of PRED.
  auto non_pred_cond = ShapeUtil::MakeProgramShape({result_shape}, s32_);
  auto cond_result_status =
      ShapeInference::InferWhileShape(non_pred_cond, body, result_shape);
  ASSERT_FALSE(cond_result_status.ok());
  ASSERT_THAT(cond_result_status.status().error_message(),
              HasSubstr("Condition must return a boolean"));

  // Body returning a shape different from the loop state.
  auto wrong_result_body =
      ShapeUtil::MakeProgramShape({result_shape}, vector_32_);
  auto body_result_status =
      ShapeInference::InferWhileShape(cond, wrong_result_body, result_shape);
  ASSERT_FALSE(body_result_status.ok());
  ASSERT_THAT(body_result_status.status().error_message(),
              HasSubstr("parameter of condition and body"));
}
2004
2005 // Tests for the concatenate instruction with dynamic shapes.
TEST_F(ShapeInferenceTest, ConcatenateWithDynamicShapes) {
  // Concatenating along dimension 0 sums that dimension; a dynamic bit on
  // the concat dimension of either operand makes the result dynamic there.
  const Shape dynamic_shape_1 =
      ShapeUtil::MakeShape(F32, {32, 160, 10}, {true, false, false});
  const Shape dynamic_shape_2 =
      ShapeUtil::MakeShape(F32, {32, 160, 10}, {false, true, false});
  auto concat_status = ShapeInference::InferConcatOpShape(
      {&dynamic_shape_1, &dynamic_shape_2}, /*dimension=*/0);
  ASSERT_IS_OK(concat_status.status());
  const Shape concat_shape = concat_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(
      ShapeUtil::MakeShape(F32, {64, 160, 10}, {true, true, false}),
      concat_shape));
}
2018
2019 // Tests for the concatenate instruction with proper shapes.
TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) {
  // Concatenation sums the sizes along the concat dimension and keeps the
  // others: vectors along dim 0, matrices along dim 1.
  auto two_vectors_status = ShapeInference::InferConcatOpShape(
      {&vector_32_, &vector_64_}, /*dimension=*/0);
  ASSERT_IS_OK(two_vectors_status.status());
  const Shape two_vectors_shape = two_vectors_status.ValueOrDie();
  ASSERT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), two_vectors_shape));

  auto three_vectors_status = ShapeInference::InferConcatOpShape(
      {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0);
  ASSERT_IS_OK(three_vectors_status.status());
  const Shape three_vectors_shape = three_vectors_status.ValueOrDie();
  ASSERT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), three_vectors_shape));

  auto three_matrices_status = ShapeInference::InferConcatOpShape(
      {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1);
  ASSERT_IS_OK(three_matrices_status.status());
  const Shape three_matrices_shape = three_matrices_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}),
                               three_matrices_shape));
}
2040
2041 // Tests for the concatenate instruction with wrong shapes.
TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
  // Invalid concatenate arguments, each with its own diagnostic.

  // No operands at all.
  auto no_args_status = ShapeInference::InferConcatOpShape({}, /*dimension=*/0);
  ASSERT_FALSE(no_args_status.ok());
  ASSERT_THAT(no_args_status.status().error_message(),
              HasSubstr("Concatenate expects at least one argument"));

  // Negative concat dimension.
  auto negative_dim_status =
      ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1);
  ASSERT_FALSE(negative_dim_status.ok());
  ASSERT_THAT(negative_dim_status.status().error_message(),
              HasSubstr("dimension out of bounds: -1"));

  // Concat dimension beyond the operand's rank.
  auto oob_dim_status =
      ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1);
  ASSERT_FALSE(oob_dim_status.ok());
  ASSERT_THAT(oob_dim_status.status().error_message(),
              HasSubstr("dimension out of bounds: 1"));

  // Tuple operands are not arrays and cannot be concatenated.
  const Shape tuple = ShapeUtil::MakeTupleShape({vector_32_});
  auto tuple_operand_status = ShapeInference::InferConcatOpShape(
      {&vector_32_, &tuple}, /*dimension=*/0);
  ASSERT_FALSE(tuple_operand_status.ok());
  ASSERT_THAT(
      tuple_operand_status.status().error_message(),
      HasSubstr("Expected array argument for operand of concatenation"));

  // Element types must agree (F32 vs S32).
  const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
  auto mixed_type_status = ShapeInference::InferConcatOpShape(
      {&vector_32_, &vector_s32}, /*dimension=*/0);
  ASSERT_FALSE(mixed_type_status.ok());
  ASSERT_THAT(mixed_type_status.status().error_message(),
              HasSubstr("concatenate arrays with different element types"));

  // Non-concat dimensions must agree (48 vs 64 along dimension 1).
  auto mismatched_dim_status = ShapeInference::InferConcatOpShape(
      {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0);
  ASSERT_FALSE(mismatched_dim_status.ok());
  ASSERT_THAT(mismatched_dim_status.status().error_message(),
              HasSubstr("concatenate arrays that differ in "
                        "dimensions other than the one being "
                        "concatenated"));
}
2084
TEST_F(ShapeInferenceTest, Pad) {
  const Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});
  const Shape padding_value_shape = ShapeUtil::MakeShape(F32, {});
  // Dimension 0 gets {low: 0, high: 2, interior: 3}; dimension 1 gets
  // {low: 1, high: 5, interior: 0}.
  PaddingConfig padding_config;
  auto* dim0 = padding_config.add_dimensions();
  dim0->set_edge_padding_low(0);
  dim0->set_edge_padding_high(2);
  dim0->set_interior_padding(3);
  auto* dim1 = padding_config.add_dimensions();
  dim1->set_edge_padding_low(1);
  dim1->set_edge_padding_high(5);
  dim1->set_interior_padding(0);

  // Dim 0: 10 + 9 * 3 interior + 0 + 2 == 39.  Dim 1: 25 + 1 + 5 == 31.
  auto statusor = ShapeInference::InferPadShape(input_shape,
                                                padding_value_shape,
                                                padding_config);
  ASSERT_IS_OK(statusor.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}),
                               statusor.ValueOrDie()));

  // Negative edge padding that shrinks dimension 1 below zero is rejected.
  dim1->set_edge_padding_low(-20);
  dim1->set_edge_padding_high(-10);
  auto negative_dimension_size = ShapeInference::InferPadShape(
      input_shape, padding_value_shape, padding_config);
  ASSERT_FALSE(negative_dimension_size.ok());
  ASSERT_THAT(negative_dimension_size.status().error_message(),
              HasSubstr("negative size for dimension 1"));
}
2115
TEST_F(ShapeInferenceTest, Reverse) {
  // Reversing dimensions never changes the shape itself.
  const Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});
  auto statusor = ShapeInference::InferReverseShape(input_shape, {0, 1});
  ASSERT_IS_OK(statusor.status());
  ASSERT_TRUE(ShapeUtil::Equal(input_shape, statusor.ValueOrDie()));
}
2124
TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
  const Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});

  // Dimension index beyond the rank.
  auto error_oob = ShapeInference::InferReverseShape(input_shape, {0, 2});
  ASSERT_FALSE(error_oob.ok());
  ASSERT_THAT(error_oob.status().error_message(), HasSubstr("out-of-bounds"));

  // Negative dimension index.
  auto error_negative =
      ShapeInference::InferReverseShape(input_shape, {0, -1});
  ASSERT_FALSE(error_negative.ok());
  ASSERT_THAT(error_negative.status().error_message(),
              HasSubstr("out-of-bounds"));

  // The same dimension listed twice.
  auto error_duplicate =
      ShapeInference::InferReverseShape(input_shape, {0, 0});
  ASSERT_FALSE(error_duplicate.ok());
  ASSERT_THAT(error_duplicate.status().error_message(),
              HasSubstr("duplicated"));

  // Tuples cannot be reversed.
  Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape});
  auto error_tuple = ShapeInference::InferReverseShape(tuple_shape, {0});
  ASSERT_FALSE(error_tuple.ok());
  ASSERT_THAT(error_tuple.status().error_message(),
              HasSubstr("Expected array argument"));
}
2153
TEST_F(ShapeInferenceTest, Call) {
  // Nullary computation: the call's shape is the computation's result shape.
  auto call0 =
      ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({}, f32_));
  EXPECT_IS_OK(call0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, call0.ValueOrDie()));

  // Five arguments, each matching its parameter.
  auto call1 = ShapeInference::InferCallShape(
      {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_},
      ShapeUtil::MakeProgramShape(
          {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_));
  EXPECT_IS_OK(call1.status());
  EXPECT_TRUE(ShapeUtil::Equal(s32matrix_64_64_, call1.ValueOrDie()));

  // Too few arguments.
  auto error0 = ShapeInference::InferCallShape(
      {}, ShapeUtil::MakeProgramShape({f32_}, f32_));
  EXPECT_FALSE(error0.ok());
  EXPECT_THAT(error0.status().error_message(), HasSubstr("arity must match"));

  // Too many arguments.
  auto error1 = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({}, f32_));
  EXPECT_FALSE(error1.ok());
  EXPECT_THAT(error1.status().error_message(), HasSubstr("arity must match"));

  // Argument shape does not match the parameter (F32 vs S32).
  auto error2 = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_));
  EXPECT_FALSE(error2.ok());
  EXPECT_THAT(error2.status().error_message(),
              HasSubstr("parameter must match argument"));
}
2186
TEST_F(ShapeInferenceTest, Transpose) {
  // Permutation {1, 2, 3, 0} rotates [2, 3, 4, 5] into [3, 4, 5, 2].
  const Shape a_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
  auto statusor = ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0});
  EXPECT_IS_OK(statusor);
  EXPECT_TRUE(ShapeUtil::Compatible(statusor.ValueOrDie(),
                                    ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
}
2196
TEST_F(ShapeInferenceTest, Rank1Transpose) {
  // The identity permutation on a rank-1 shape is a no-op.
  const Shape a_shape = ShapeUtil::MakeShape(F32, {5});
  auto statusor = ShapeInference::InferTransposeShape(a_shape, {0});
  EXPECT_IS_OK(statusor);
  EXPECT_TRUE(ShapeUtil::Compatible(statusor.ValueOrDie(),
                                    ShapeUtil::MakeShape(F32, {5})));
}
2206
TEST_F(ShapeInferenceTest, ConditionalPred) {
  // Both branches return f32, so the conditional's shape is f32.
  auto inferred_status0 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));

  // Branch operands may have different shapes, as long as each matches its
  // own branch computation's parameter.
  auto inferred_status1 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
       ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)},
      {matrix_32_48_, vector_32_});
  EXPECT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));

  // A tuple-shaped branch operand counts as a single argument.
  auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
  auto inferred_status2 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
       ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
      {matrix_32_48_, tuple_f32_v32});
  EXPECT_IS_OK(inferred_status2.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));

  // An F32 selector is rejected; only PRED (or S32 for the indexed form)
  // is allowed.
  auto inferred_status_error0 = ShapeInference::InferConditionalShape(
      f32_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error0.ok());
  EXPECT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("must be bool or int32"));

  // Branch computation 0 declares two parameters; exactly one is required.
  auto inferred_status_error1 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
      {ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_});
  EXPECT_FALSE(inferred_status_error1.ok());
  EXPECT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("branch computation 0 must take 1 argument"));

  // Branch operand 0 (vector_32_) does not match branch computation 0's
  // parameter (vector_64_).
  auto inferred_status_error2 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_64_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error2.ok());
  EXPECT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("branch operand 0 must match the shape of the only "
                        "parameter of branch computation 0"));

  // Branch computation 1 declares two parameters; exactly one is required.
  auto inferred_status_error3 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
       ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)},
      {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})});
  EXPECT_FALSE(inferred_status_error3.ok());
  EXPECT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("branch computation 1 must take 1 argument"));

  // Branch operand 1 (vector_64_) does not match branch computation 1's
  // parameter (vector_32_).
  auto inferred_status_error4 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error4.ok());
  EXPECT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("branch operand 1 must match the shape of the only "
                        "parameter of branch computation 1"));

  // The branches' result shapes must agree (f32 vs vector_32_).
  auto inferred_status_error5 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error5.ok());
  EXPECT_THAT(inferred_status_error5.status().error_message(),
              HasSubstr("the result of branch 0 computation and branch 1 "
                        "computation must have the same shape"));
}
2290
TEST_F(ShapeInferenceTest, ConditionalIndexed) {
  // An S32 scalar selector allows an arbitrary number of branches.
  auto r0s32 = ShapeUtil::MakeShape(S32, {});
  auto inferred_status0 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_, vector_64_});
  EXPECT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));

  // Branch operand shapes may differ across branches, as long as each
  // matches its own branch's parameter.
  auto inferred_status1 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
       ShapeUtil::MakeProgramShape({vector_32_}, vector_64_),
       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)},
      {matrix_32_48_, vector_32_, matrix_32_48_});
  EXPECT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));

  // A single branch is valid for the indexed form; a tuple-shaped operand
  // is a single argument.
  auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
  auto inferred_status2 = ShapeInference::InferConditionalShape(
      r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
      {tuple_f32_v32});
  EXPECT_IS_OK(inferred_status2.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));

  // A PRED selector requires exactly two branches; three are given.
  auto inferred_status_error0 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error0.ok());
  EXPECT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("2 == branch_computations.size()"));

  // Branch computation 1 declares two parameters; exactly one is required.
  auto inferred_status_error1 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
       ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
      {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}),
       matrix_32_48_});
  EXPECT_FALSE(inferred_status_error1.ok());
  EXPECT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("branch computation 1 must take 1 argument"));

  // Branch operand 2 (vector_64_) does not match branch computation 2's
  // parameter (vector_32_).
  auto inferred_status_error2 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({r0s32}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
      {r0s32, vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error2.ok());
  EXPECT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("branch operand 2 must match the shape of the only "
                        "parameter of branch computation 2"));

  // Branch 3's result shape (vector_32_) differs from branch 0's (f32).
  auto inferred_status_error3 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
      {vector_32_, vector_32_, vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error3.ok());
  EXPECT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("the result of branch 0 computation and branch 3 "
                        "computation must have the same shape"));

  // At least one branch computation is required.
  auto inferred_status_error4 =
      ShapeInference::InferConditionalShape(r0s32, {}, {});
  EXPECT_FALSE(inferred_status_error4.ok());
  EXPECT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("!branch_computations.empty()"));
}
2368
TEST_F(ShapeInferenceTest, ConditionalDynamic) {
  // If any branch returns a dynamic result shape, the conditional's inferred
  // shape is dynamic — regardless of which branch is listed first.
  auto r0s32 = ShapeUtil::MakeShape(S32, {});
  auto static_shape = ShapeUtil::MakeShape(S32, {4}, {false});
  auto dynamic_shape = ShapeUtil::MakeShape(S32, {4}, {true});

  auto static_branch_first = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, static_shape),
       ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape),
       ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)},
      {vector_32_, vector_64_, vector_64_});
  EXPECT_IS_OK(static_branch_first.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(dynamic_shape, static_branch_first.ValueOrDie()));

  auto dynamic_branch_first = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, dynamic_shape),
       ShapeUtil::MakeProgramShape({vector_64_}, static_shape),
       ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)},
      {vector_32_, vector_64_, vector_64_});
  EXPECT_IS_OK(dynamic_branch_first.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(dynamic_shape, dynamic_branch_first.ValueOrDie()));
}
2391
TEST_F(ShapeInferenceTest, BadSlice) {
  // Slicing [0, 5) out of a 4-element vector must fail.
  auto input = ShapeUtil::MakeShape(F32, {4});
  StatusOr<Shape> slice_status =
      ShapeInference::InferSliceShape(input, {0}, {5}, {1});
  ASSERT_FALSE(slice_status.ok());

  LOG(INFO) << slice_status.status();

  EXPECT_THAT(slice_status.status().error_message(),
              HasSubstr("less than or equal to dimension size"))
      << slice_status.status();
  EXPECT_THAT(slice_status.status().error_message(),
              HasSubstr("argument shape"))
      << slice_status.status();
}
2406
TEST_F(ShapeInferenceTest, BadSort) {
  // Keys and values of a variadic sort must have matching dimensions
  // (4 vs 5 here).
  auto keys = ShapeUtil::MakeShape(F32, {4});
  auto values = ShapeUtil::MakeShape(F32, {5});
  auto sort_status =
      ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values});
  EXPECT_FALSE(sort_status.ok());
  EXPECT_THAT(sort_status.status().error_message(),
              HasSubstr("dimensions must match"))
      << sort_status.status();
}
2417
TEST_F(ShapeInferenceTest, BadSortValuesMismatch) {
  // Every values operand must match the keys' dimensions; the second one
  // (5 elements) does not.
  auto keys = ShapeUtil::MakeShape(F32, {4});
  auto values_good = ShapeUtil::MakeShape(F32, {4});
  auto values_bad = ShapeUtil::MakeShape(F32, {5});
  auto sort_status = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &values_good, &values_bad});
  EXPECT_FALSE(sort_status.ok());
  EXPECT_THAT(sort_status.status().error_message(),
              HasSubstr("dimensions must match"))
      << sort_status.status();
}
2429
TEST_F(ShapeInferenceTest, SortManyValues) {
  // A sort with multiple values operands yields a tuple of all operand
  // shapes; element types may differ as long as dimensions match.
  auto keys = ShapeUtil::MakeShape(F32, {4});
  auto values_s32 = ShapeUtil::MakeShape(S32, {4});
  auto values_u32 = ShapeUtil::MakeShape(U32, {4});
  auto sort_status = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &values_s32, &values_u32});
  EXPECT_IS_OK(sort_status);
  EXPECT_TRUE(ShapeUtil::Compatible(
      sort_status.ValueOrDie(),
      ShapeUtil::MakeTupleShape({keys, values_s32, values_u32})));
}
2442
// Fixture providing common operand shapes for the gather/scatter shape
// inference tests below.
class ScatterGatherShapeInferenceTest : public ShapeInferenceTest {
 protected:
  // Index-operand shapes.  Note: despite the "_4d_tensor_" naming, these
  // index tensors are rank 5 — four batch dimensions plus one dimension
  // holding the index vector.
  const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});
  const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5});
  const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32});
  const Shape s64_4d_tensor_10_9_8_7_1_ =
      ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1});
  const Shape s64_4d_tensor_10_9_8_7_5_ =
      ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
  const Shape s64_4d_tensor_5_10_9_7_6_ =
      ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6});
  const Shape s64_4d_tensor_10_9_5_7_6_ =
      ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
  // Rank-5 f32 input operand used by the batch-slice style tests.
  const Shape f32_5d_tensor_50_49_48_47_46_ =
      ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  // Tuple shape used to exercise the "must be an array" error paths.
  const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
      {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_});
  // (f32, f32) -> f32 program shape; presumably the scatter combiner used by
  // tests outside this chunk — TODO confirm.
  const ProgramShape to_apply_ =
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
};
2463
2464 // Shape inference tests for Gather.
2465
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) {
  // Each of the 32 indices selects one column of the [64, 48] matrix (slice
  // sizes {64, 1} with dimension 1 collapsed), so the result is F32[64, 32].
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              matrix_64_48_, s64_vector_32_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{0},
                                  /*collapsed_slice_dims=*/{1},
                                  /*start_index_map=*/{1},
                                  /*index_vector_dim=*/1),
                              /*slice_sizes=*/{64, 1}));
  EXPECT_TRUE(
      ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
      << ShapeUtil::HumanString(gather_shape);
}
2480
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) {
  // Each of the 32 indices selects one row of the [64, 48] matrix (slice
  // sizes {1, 48} with dimension 0 collapsed), so the result is F32[32, 48].
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              matrix_64_48_, s64_vector_32_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{1},
                                  /*collapsed_slice_dims=*/{0},
                                  /*start_index_map=*/{0},
                                  /*index_vector_dim=*/1),
                              /*slice_sizes=*/{1, 48}));
  EXPECT_TRUE(
      ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
      << ShapeUtil::HumanString(gather_shape);
}
2495
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) {
  // Rank-5 indices whose trailing size-1 dimension is the index vector: each
  // index picks one row (slice sizes {1, 48}), so the batch dimensions
  // {10, 9, 8, 7} are followed by the 48-wide offset dimension.
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{4},
                                  /*collapsed_slice_dims=*/{0},
                                  /*start_index_map=*/{0},
                                  /*index_vector_dim=*/4),
                              /*slice_sizes=*/{1, 48}));
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
                               ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
      << ShapeUtil::HumanString(gather_shape);
}
2510
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
  // Each of the 10x9x8x7 length-5 index vectors selects a full
  // [30, 29, 28, 27, 26] window; no slice dimensions are collapsed, so the
  // output is the batch dimensions followed by the window dimensions.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4, 5, 6, 7, 8},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/4),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));
  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(gather_shape);
}
2527
TEST_F(ScatterGatherShapeInferenceTest, DynamicGatherEntireDimension) {
  // The dynamic middle dimension (bound 2) is sliced in its entirety, so its
  // dynamism is preserved in the output.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          ShapeUtil::MakeShape(F32, {3, 2, 1}, {false, true, false}),
          ShapeUtil::MakeShape(S64, {}),
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{0, 1},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{1, 2, 1}));
  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape, ShapeUtil::MakeShape(F32, {2, 1}, {true, false})))
      << ShapeUtil::HumanString(gather_shape);
}
2544
TEST_F(ScatterGatherShapeInferenceTest, DynamicGatherCollapsedDimension) {
  // The dynamic dimension (bound 3) is collapsed away by the gather, so the
  // output carries no dynamic dimensions.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          ShapeUtil::MakeShape(F32, {3, 2, 1}, {true, false, false}),
          ShapeUtil::MakeShape(S64, {}),
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{0, 1},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{1, 2, 1}));
  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape, ShapeUtil::MakeShape(F32, {2, 1}, {false, false})))
      << ShapeUtil::HumanString(gather_shape);
}
2561
TEST_F(ScatterGatherShapeInferenceTest, DynamicIndices) {
  // The dynamic batch dimension of the indices (bound 4) propagates to the
  // corresponding batch dimension of the gather output.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          ShapeUtil::MakeShape(F32, {3, 2, 2}),
          ShapeUtil::MakeShape(S64, {3, 4, 2}, {false, true, false}),
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{2, 3},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0, 1},
              /*index_vector_dim=*/2),
          /*slice_sizes=*/{1, 2, 2}));
  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape,
      ShapeUtil::MakeShape(F32, {3, 4, 2, 2}, {false, true, false, false})))
      << ShapeUtil::HumanString(gather_shape);
}
2579
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
  // index_vector_dim=2 makes the size-5 dimension in the middle of the
  // indices hold the index vector; the remaining dims {10, 9, 7, 6} become
  // the output batch dimensions.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4, 5, 6, 7, 8},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/2),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));

  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape,
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(gather_shape);
}
2597
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
  // index_vector_dim=0 makes the leading size-5 dimension of the indices
  // hold the index vector; the remaining dims {10, 9, 7, 6} become the
  // output batch dimensions.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4, 5, 6, 7, 8},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));

  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape,
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(gather_shape);
}
2615
TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) {
  // This is equivalent to a dynamic slice.
  // The single length-5 index vector supplies a complete start index; every
  // output dimension is an offset dimension, so the result is just the
  // slice shape.
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{0, 1, 2, 3, 4},
                                  /*collapsed_slice_dims=*/{},
                                  /*start_index_map=*/{0, 1, 2, 3, 4},
                                  /*index_vector_dim=*/0),
                              /*slice_sizes=*/{30, 29, 28, 27, 26}));

  EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
                               ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(gather_shape);
}
2632
TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) {
  // The gather indices "tensor" is a scalar S here that's used to slice out
  // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result.
  // With index_vector_dim equal to the (zero) rank of the indices, the
  // scalar acts as a length-1 index vector addressing input dimension 0,
  // which is then collapsed away.
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{0, 1, 2, 3},
                                  /*collapsed_slice_dims=*/{0},
                                  /*start_index_map=*/{0},
                                  /*index_vector_dim=*/0),
                              /*slice_sizes=*/{1, 30, 29, 28, 27}));

  EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
                               ShapeUtil::MakeShape(F32, {30, 29, 28, 27})))
      << ShapeUtil::HumanString(gather_shape);
}
2650
TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) {
  // The gather input operand must be an array, not a tuple.
  StatusOr<Shape> gather_status = ShapeInference::InferGatherShape(
      tuple_shape_, s64_vector_32_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/1),
      /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(gather_status.ok());
  EXPECT_THAT(gather_status.status().error_message(),
              HasSubstr("Expected array argument for input"))
      << gather_status.status();
}
2665
TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
  // The gather indices operand must be an array, not a tuple.
  StatusOr<Shape> gather_status = ShapeInference::InferGatherShape(
      s64_vector_32_, tuple_shape_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/0),
      /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(gather_status.ok());
  EXPECT_THAT(gather_status.status().error_message(),
              HasSubstr("Expected array argument for gather indices"))
      << gather_status.status();
}
2680
TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
  // Gather indices must have an integral element type, not F32.
  StatusOr<Shape> gather_status = ShapeInference::InferGatherShape(
      s64_vector_32_, vector_32_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/0),
      /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(gather_status.ok());
  EXPECT_THAT(gather_status.status().error_message(),
              HasSubstr("Gather indices parameter must be an integral tensor"))
      << gather_status.status();
}
2695
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingWindowIndices) {
  // offset_dims must be sorted ascending; {..., 8, 7} is not.
  StatusOr<Shape> gather_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 8, 7},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(gather_status.ok());
  EXPECT_THAT(
      gather_status.status().error_message(),
      HasSubstr("Output window dimensions in gather op must be ascending"))
      << gather_status.status();
}
2712
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowIndices) {
  // offset_dims must not contain duplicates; 7 appears twice.
  StatusOr<Shape> gather_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 7},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(gather_status.ok());
  EXPECT_THAT(
      gather_status.status().error_message(),
      HasSubstr("Output window dimensions in gather op must not repeat"))
      << gather_status.status();
}
2729
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
  // Offset dimensions 99/100/101 are far beyond the output rank.
  StatusOr<Shape> gather_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 99, 100, 101},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(gather_status.ok());
  EXPECT_THAT(gather_status.status().error_message(),
              HasSubstr("Offset dimension 2 in gather op is out of bounds"))
      << gather_status.status();
}
2745
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) {
  // With 4 batch dims and 5 offset dims the output has rank 9, so offset
  // dimension 9 is just past the valid range [0, 9).
  StatusOr<Shape> gather_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 9},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(gather_status.ok());
  EXPECT_THAT(gather_status.status().error_message(),
              HasSubstr("Offset dimension 4 in gather op is out of bounds"))
      << gather_status.status();
}
2761
// Every input dimension must be either an offset dim or explicitly collapsed;
// here dimension 4 is collapsed yet offset_dims still covers five dims.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{4},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("All components of the offset index in a gather op must either "
                "be a offset dimension or explicitly collapsed"))
      << inferred_status.status();
}
2779
// collapsed_slice_dims entries must lie inside the operand's rank.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{0, 1, 2, 3, 19},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Invalid collapsed_slice_dims set in gather op; valid "
                        "range is [0, 5), got: 19"))
      << inferred_status.status();
}
2796
// collapsed_slice_dims entries must be unique.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{0, 1, 2, 3, 3},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Repeated dimensions not allowed in "
                        "collapsed_slice_dims in gather op"))
      << inferred_status.status();
}
2813
// start_index_map length must equal the bound of the index_vector_dim
// dimension of start_indices (4 vs. 5 here).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Gather op has 4 elements in start_index_map and "
                        "the bound of dimension index_vector_dim=4 of "
                        "start_indices is 5. These two numbers must be equal."))
      << inferred_status.status();
}
2831
// start_index_map entries must index an existing operand dimension.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 7},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7"))
      << inferred_status.status();
}
2847
// start_index_map entries must be unique.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 3},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Repeated dimensions are not allowed in start_index_map"))
      << inferred_status.status();
}
2864
// collapsed_slice_dims must be given in ascending order.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{2, 1},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{1, 1, 28, 27, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("collapsed_slice_dims in gather op must be sorted"))
      << inferred_status.status();
}
2880
// A slice size larger than the corresponding operand bound (300 > 47) fails.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowBoundsTooLarge) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7},
      /*collapsed_slice_dims=*/{2},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 1, 300, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Slice size at index 3 in gather op is out of range, "
                        "must be within [0, 48), got 300."))
      << inferred_status.status();
}
2897
// slice_sizes must contain exactly one entry per operand dimension.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Gather op must have one slice size for every input dimension"))
      << inferred_status.status();
}
2914
// A collapsed slice dim must have slice size 0 or 1; dim 1 has 29 here.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7},
      /*collapsed_slice_dims=*/{1},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 26, 20});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Gather op can only collapse slice dims with bound 1 or 0, "
                "but bound is 29 for index 1 at position 0."))
      << inferred_status.status();
}
2932
// index_vector_dim must be within [0, rank(start_indices)]; 32 is not.
TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/32);
  const StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Gather index leaf dimension must be within [0, "
                        "rank(start_indices) + 1)"))
      << inferred_status.status();
}
2949
// Shape inference tests for Scatter.

// TF-style scatter whose updates cover whole columns of the operand.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdates) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {64, 32}),
          to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
2965
// Same as TfScatterWithFullUpdates with the roles of the two dims swapped.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdatesV2) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{1},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {32, 48}),
          to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
2979
// Updates smaller than the operand's window dimension are allowed.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdates) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {10, 32}),
          to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
2993
// Partial-update variant with the window on the second updates dimension.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdatesV2) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{1},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {32, 8}),
          to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
3007
// Update window larger than the operand dimension (65 > 64) must fail.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {65, 32}),
      to_apply_, dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << inferred_status.status();
}
3024
// Same oversize-window failure, exercised on the other operand dimension.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{1},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {32, 49}),
      to_apply_, dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << inferred_status.status();
}
3041
// Scatter dim of updates (31) must match the indices bound (32).
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterWithUpdatesNotMatchingIndices) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {64, 31}),
      to_apply_, dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << inferred_status.status();
}
3060
// Same indices/updates mismatch, with the scatter dim first in updates.
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterWithUpdatesNotMatchingIndicesV2) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{1},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {31, 48}),
      to_apply_, dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << inferred_status.status();
}
3079
// tf.scatter_nd-style update with a multi-dimensional index tensor.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdates) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}), to_apply_,
          dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
3094
// scatter_nd variant inserting window dim 1 instead of 0.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdatesV2) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 64}), to_apply_,
          dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
3109
// scatter_nd with a window narrower than the operand dimension.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdates) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 10}), to_apply_,
          dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
3124
// Partial scatter_nd update inserting window dim 1.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 12}), to_apply_,
          dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
3139
// scatter_nd window of 65 exceeds the operand bound of 64 and must fail.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 65}), to_apply_, dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << inferred_status.status();
}
3156
// scatter_nd: batch dims of updates (9,...) must match the indices (10,...).
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterNdWithUpdatesNotMatchingIndices) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
      ShapeUtil::MakeShape(F32, {9, 9, 8, 7, 64}), to_apply_, dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << inferred_status.status();
}
3175
// A scatter whose window spans every operand dim behaves like a batched
// dynamic-update-slice; the result keeps the operand's shape.
TEST_F(ScatterGatherShapeInferenceTest, TfBatchDynamicUpdateSlice) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 8},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}),
          to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred_shape);
}
3191
// The index vector may live in a middle dimension of scatter_indices.
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDim) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 8},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/2);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
          ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}),
          to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred_shape);
}
3208
// The index vector may also be the leading dimension of scatter_indices.
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 8},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
          ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}),
          to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred_shape);
}
3225
// With no scatter (batch) dims in updates, this is equivalent to a plain
// dynamic update slice.
TEST_F(ScatterGatherShapeInferenceTest, NoUpdateScatterDims) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0, 1, 2, 3, 4},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
          ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}), to_apply_,
          dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred_shape);
}
3242
// A scalar index S selects where a [30,29,28,27]-shaped update is written
// into the operand.
TEST_F(ScatterGatherShapeInferenceTest, ScalarScatterIndices) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0, 1, 2, 3},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
          ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_,
          dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred_shape);
}
3260
// A tuple-shaped operand is rejected: scatter only accepts arrays.
TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedTensorInput) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      tuple_shape_, s64_vector_32_, s64_vector_32_, to_apply_, dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Expected array argument for operand"))
      << inferred_status.status();
}
3274
// A tuple-shaped indices tensor is rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       ScatterWithTupleShapedScatterIndicesInput) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/0);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      s64_vector_32_, tuple_shape_, s64_vector_32_, to_apply_, dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Expected array argument for scatter indices"))
      << inferred_status.status();
}
3289
// A tuple-shaped updates tensor is rejected.
TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/0);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      s64_vector_32_, s64_vector_32_, tuple_shape_, to_apply_, dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Expected array argument for updates"))
      << inferred_status.status();
}
3303
// Scatter indices must be integral; an F32 vector is rejected.
TEST_F(ScatterGatherShapeInferenceTest, FloatingPointScatterIndicesInput) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/0);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      s64_vector_32_, vector_32_, s64_vector_32_, to_apply_, dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Scatter indices parameter must be an integral tensor"))
      << inferred_status.status();
}
3317
// index_vector_dim must be within [0, rank(scatter_indices)]; 10 is not.
TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/10);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Scatter index leaf dimension must be within [0, "
                        "rank(scatter_indices) + 1)"))
      << inferred_status.status();
}
3333
// Updates of rank 8 are rejected when the dim numbers imply rank 7.
TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdates) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 50}), to_apply_,
      dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Updates tensor must be of rank 7; got 8."))
      << inferred_status.status();
}
3348
// The scatter reduction must be binary; a unary computation is rejected.
TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdateComputation) {
  const ProgramShape invalid_update_computation =
      ShapeUtil::MakeProgramShape({f32_}, f32_);
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}),
      invalid_update_computation, dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Reduction function must take 2 parameters, but takes 1"))
      << inferred_status.status();
}
3367
// update_window_dims must be given in ascending order.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 8, 7},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
      dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("update_window_dims in scatter op must be sorted"))
      << inferred_status.status();
}
3383
// update_window_dims entries must be unique.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedUpdateWindowDims) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 7},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
      dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("update_window_dims in scatter op must not repeat"))
      << inferred_status.status();
}
3399
// update_window_dims entries must lie inside the updates rank (9 here).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 9},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
      dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Invalid update_window_dims set in scatter op; valid "
                        "range is [0, 9)"))
      << inferred_status.status();
}
3416
// inserted_window_dims must be given in ascending order.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{2, 1},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("inserted_window_dims in scatter op must be sorted"))
      << inferred_status.status();
}
3432
// inserted_window_dims entries must be unique.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedInsertedWindowDims) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 1},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      dim_numbers);
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("inserted_window_dims in scatter op must not repeat"))
      << inferred_status.status();
}
3448
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) {
  // Each inserted_window_dim must index into the rank-5 operand; 5 is one
  // past the valid range [0, 5), so inference should fail.
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 5},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Invalid inserted_window_dims set in scatter op; valid "
                        "range is [0, 5)"))
      << result.status();
}
3465
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) {
  // scatter_dims_to_operand_dims has 4 entries but the index_vector_dim
  // bound of scatter_indices is 5; the sizes must match, so this fails.
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and "
                "the bound of dimension index_vector_dim=4 of scatter_indices "
                "is 5. These two numbers must be equal"))
      << result.status();
}
3484
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) {
  // The mapping 4->10 points outside the rank-5 operand's dimension range
  // [0, 5), so shape inference must report the offending entry.
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain "
                        "is [0, 5), got: 4->10"))
      << result.status();
}
3501
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) {
  // scatter_dims_to_operand_dims maps two index dims to operand dim 2;
  // duplicate targets are invalid and must be rejected.
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr(
          "Repeated dimensions not allowed in scatter_dims_to_operand_dims"))
      << result.status();
}
3519
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_InsufficientWindowDims) {
  // update_window_dims + inserted_window_dims together cover only 4 window
  // dimensions, but the operand has rank 5 — inference must flag the gap.
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
      ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0, 1, 2, 3},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr(
          "Scatter op has window of size 4; doesn't match operand of rank 5."))
      << result.status();
}
3537
3538 } // namespace
3539 } // namespace xla
3540