1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17
18 #include <string>
19
20 #include "absl/strings/string_view.h"
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/xla/client/padding.h"
23 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/test.h"
26 #include "tensorflow/compiler/xla/test_helpers.h"
27 #include "tensorflow/compiler/xla/types.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29
30 namespace xla {
31 namespace {
32
33 using ::testing::ContainsRegex;
34 using ::testing::HasSubstr;
35
36 class ShapeInferenceTest : public ::testing::Test {
37 protected:
38 // Some handy scalar shapes.
39 const Shape s32_ = ShapeUtil::MakeShape(S32, {});
40 const Shape f16_ = ShapeUtil::MakeShape(F16, {});
41 const Shape f32_ = ShapeUtil::MakeShape(F32, {});
42 const Shape f64_ = ShapeUtil::MakeShape(F64, {});
43 const Shape pred_ = ShapeUtil::MakeShape(PRED, {});
44
45 // Some handy vector and matrix shapes of F32 type.
46 // Suffix: vector_length_, matrix_rows_cols_
47 const Shape vector_32_ = ShapeUtil::MakeShape(F32, {32});
48 const Shape vector_64_ = ShapeUtil::MakeShape(F32, {64});
49 const Shape matrix_32_48_ = ShapeUtil::MakeShape(F32, {32, 48});
50 const Shape matrix_32_64_ = ShapeUtil::MakeShape(F32, {32, 64});
51 const Shape matrix_64_48_ = ShapeUtil::MakeShape(F32, {64, 48});
52
53 // Some handy S32 arrays.
54 const Shape s32matrix_64_64_ = ShapeUtil::MakeShape(S32, {64, 64});
55 };
56
57 // Subclass for testing InferReduceShape.
58 class ReduceShapeInferenceTest : public ShapeInferenceTest {
59 protected:
60 // Helper that runs reduce shape inference with the input 'arg' and given
61 // dimensions to reduce, and checks the inferred shape is as expected. The
62 // element type here is hard-coded to F32.
ExpectInferredReduceShape(const Shape & expected_inferred_shape,const Shape & arg,absl::Span<const int64> dimensions_to_reduce)63 void ExpectInferredReduceShape(const Shape& expected_inferred_shape,
64 const Shape& arg,
65 absl::Span<const int64> dimensions_to_reduce) {
66 ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
67 auto inferred_status = ShapeInference::InferReduceShape(
68 {&arg, &f32_}, dimensions_to_reduce, to_apply);
69 EXPECT_IS_OK(inferred_status.status());
70 EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
71 inferred_status.ValueOrDie()));
72 }
73 };
74
75 // Subclass for testing InferSelectAndScatterShape.
76 class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest {
77 protected:
SelectAndScatterShapeInferenceTest()78 SelectAndScatterShapeInferenceTest() {
79 operand_shape_ = ShapeUtil::MakeShape(F32, {8, 16});
80 source_shape_ = ShapeUtil::MakeShape(F32, {4, 8});
81 WindowDimension dim;
82 dim.set_size(2);
83 dim.set_stride(2);
84 dim.set_padding_low(0);
85 dim.set_padding_high(0);
86 dim.set_window_dilation(1);
87 dim.set_base_dilation(1);
88 *window_.add_dimensions() = dim;
89 *window_.add_dimensions() = dim;
90 init_value_shape_ = ShapeUtil::MakeShape(F32, {});
91 select_program_shape_ = ShapeUtil::MakeProgramShape(
92 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
93 scatter_program_shape_ = ShapeUtil::MakeProgramShape(
94 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
95 }
96
97 Shape operand_shape_;
98 Shape source_shape_;
99 Window window_;
100 Shape init_value_shape_;
101 ProgramShape select_program_shape_;
102 ProgramShape scatter_program_shape_;
103 };
104
// Negation is shape-preserving: the inferred shape equals the operand's.
TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred =
      ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape);
  ASSERT_IS_OK(inferred.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred.ValueOrDie()));
}
112
// Select requires array operands; tuple operands are rejected even when the
// predicate is a scalar.
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) {
  const Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_});
  auto result = ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, pred_,
                                                    tuple, tuple);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Expected array argument for select"));
}
121
// A scalar predicate with matrix operands is rejected: the predicate must
// have the same shape as the operands.
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
  auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(
      result.status().error_message(),
      HasSubstr("Operands to select and predicate must be the same shape"));
}
130
// With a PRED array whose shape matches the operands, select infers the
// operands' shape.
TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) {
  const Shape predarray = ShapeUtil::MakeShape(PRED, {64, 48});
  auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, result.ValueOrDie()));
}
138
// A collection of invalid select signatures and the diagnostics they produce.
TEST_F(ShapeInferenceTest, SelectBadShapes) {
  // Operands with different shapes.
  auto mismatched_operands = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_);
  ASSERT_FALSE(mismatched_operands.ok());
  ASSERT_THAT(mismatched_operands.status().error_message(),
              HasSubstr("Operands to select must be the same shape"));

  // Predicate of the wrong element type.
  auto non_pred_predicate = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(non_pred_predicate.ok());
  ASSERT_THAT(non_pred_predicate.status().error_message(),
              HasSubstr("pred operand must have PRED"));

  // Predicate of the wrong rank/shape.
  auto mismatched_predicate = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_,
      matrix_64_48_);
  ASSERT_FALSE(mismatched_predicate.ok());
  ASSERT_THAT(
      mismatched_predicate.status().error_message(),
      HasSubstr("Operands to select and predicate must be the same shape"));

  // Tuples have a TUPLE element type and cannot be the pred of a select.
  auto tuple_predicate = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}),
      ShapeUtil::MakeTupleShape({f32_, f32_}),
      ShapeUtil::MakeTupleShape({f32_, f32_}));
  ASSERT_FALSE(tuple_predicate.ok());
  ASSERT_THAT(tuple_predicate.status().error_message(),
              HasSubstr("Expected array argument for select pred"));
}
169
// Clamp of three identically-shaped matrices infers that same matrix shape.
TEST_F(ShapeInferenceTest, ClampAllMatrix) {
  auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, result.ValueOrDie()));
}
176
// Clamp of three scalars infers a scalar.
TEST_F(ShapeInferenceTest, ClampAllScalar) {
  auto result =
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, result.ValueOrDie()));
}
183
// A scalar min mixed with matrix operand/max is rejected.
TEST_F(ShapeInferenceTest, ClampMinScalar) {
  auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
191
// A scalar max mixed with matrix min/operand is rejected.
TEST_F(ShapeInferenceTest, ClampMaxScalar) {
  auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
199
// A scalar operand mixed with matrix min/max is rejected.
TEST_F(ShapeInferenceTest, ClampOperandScalar) {
  auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
207
// A matrix min combined with scalar operand/max is rejected.
TEST_F(ShapeInferenceTest, ClampMinMatrix) {
  auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, f32_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
215
// A matrix max combined with scalar min/operand is rejected.
TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
  auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, f32_, matrix_64_48_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
223
// A matrix operand combined with scalar min/max is rejected.
TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
  auto result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, f32_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
231
// Clamp shape inference rejects mismatched element types and mismatched
// dimensions (including scalar/vector mixes).
TEST_F(ShapeInferenceTest, ClampBadShapes) {
  // Returns true iff clamp inference fails for the given (min, operand, max).
  auto clamp_fails = [](const Shape& min, const Shape& operand,
                        const Shape& max) {
    return !ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, min,
                                                operand, max)
                .ok();
  };
  // Type mismatch
  ASSERT_TRUE(clamp_fails(s32_, f32_, f32_));
  ASSERT_TRUE(clamp_fails(f32_, s32_, f32_));
  ASSERT_TRUE(clamp_fails(f32_, f32_, s32_));
  // Dimension mismatch
  ASSERT_TRUE(clamp_fails(vector_64_, vector_32_, vector_32_));
  ASSERT_TRUE(clamp_fails(vector_32_, vector_64_, vector_32_));
  ASSERT_TRUE(clamp_fails(vector_32_, vector_32_, vector_64_));
  // Dimension mismatch, where one operand is a scalar
  ASSERT_TRUE(clamp_fails(vector_64_, vector_32_, f32_));
  ASSERT_TRUE(clamp_fails(vector_64_, f32_, vector_32_));
  ASSERT_TRUE(clamp_fails(f32_, vector_64_, vector_32_));
}
264
// Exercises kComplex shape inference: invalid element-type combinations are
// rejected; valid F32/F64 combinations (with and without broadcasting)
// produce C64/C128 results.
TEST_F(ShapeInferenceTest, Complex) {
  auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
                           absl::Span<const int64> bcast) {
    return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
                                              bcast);
  };
  // Inputs must be FP.
  ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
  ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
  // Component types must match.
  ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
  // Only F32->C64 and F64->C128 supported.
  ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok());
  // Validate correct uses: scalar/scalar, vector/scalar, scalar/vector, and
  // vector/vector all yield a C64 result of the array operand's dimensions.
  Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
  TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {})));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  // Fixed: this case previously duplicated (vector_32_, f32_) above and the
  // vector/vector combination was never covered.
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, vector_32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));

  // Broadcasting a vector against a matrix along dimension 1, in either
  // argument order, and matrix/matrix and matrix/scalar combinations.
  Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64});
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(vector_64_, matrix_32_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, vector_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, matrix_32_64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));

  // F64 components produce C128.
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {})));
}
305
// kTuple wraps its operands' shapes into a tuple shape.
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
  StatusOr<Shape> result =
      ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_});
  ASSERT_IS_OK(result.status());
  const Shape expected = ShapeUtil::MakeTupleShape({s32_, f32_});
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), expected));
}
313
// An 8x8 input reduced with a 2x2 window at stride 2 (no padding/dilation)
// yields a 4x4 result.
TEST_F(ShapeInferenceTest, ReduceWindowInHalf) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8});
  // 2x2 window, stride 2, no padding or dilation, in both dimensions.
  Window window;
  WindowDimension dim;
  dim.set_size(2);
  dim.set_stride(2);
  dim.set_padding_low(0);
  dim.set_padding_high(0);
  dim.set_window_dilation(1);
  dim.set_base_dilation(1);
  *window.add_dimensions() = dim;
  *window.add_dimensions() = dim;
  Shape init_value_shape = ShapeUtil::MakeShape(F32, {});
  // Reducer: (F32, F32) -> F32.
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  // Note: removed previously unused locals 'window_shape' and 'float_scalar'.
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      matrix_shape, init_value_shape, window, to_apply);

  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), inferred));
}
338
// With consistent fixture shapes, select-and-scatter infers the operand's
// shape.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) {
  auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, result.ValueOrDie()));
}
347
// A source whose shape disagrees with the windowed operand is rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) {
  const Shape bad_source_shape = ShapeUtil::MakeShape(F32, {4, 6});
  auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, bad_source_shape,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Source shape does not match"));
}
357
// The select computation must be binary; a unary one is rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
  const ProgramShape unary_select =
      ShapeUtil::MakeProgramShape({ShapeUtil::MakeShape(F32, {})}, pred_);
  auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, unary_select, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function must take 2 parameters"));
}
368
// The select computation must return a rank-0 PRED; an F32 result is
// rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
  const ProgramShape f32_returning_select = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, f32_returning_select, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function must have rank-0 PRED"));
}
379
// The select computation's first parameter must match the operand element
// type; S32 is rejected for an F32 operand.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
  const ProgramShape s32_param_select = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
  auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, s32_param_select, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function's first parameter"));
}
390
// The select computation's second parameter must also match the operand
// element type; U32 is rejected for an F32 operand.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
  const ProgramShape u32_param_select = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(U32, {})}, pred_);
  auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, u32_param_select, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function's second parameter"));
}
401
// Basic convolution shape inference with non-trivial (non-identity) input
// and kernel dimension orderings and per-dimension window configuration.
TEST_F(ShapeInferenceTest, Convolve) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  // x0: input 3, window 3, pad 1+1, stride 2 -> (3 + 2 - 3) / 2 + 1 = 2.
  dim0->set_size(3);
  dim0->set_stride(2);
  dim0->set_padding_low(1);
  dim0->set_padding_high(1);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(1);
  // x1: input 4, window 2, no padding, stride 1 -> (4 - 2) / 1 + 1 = 3.
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(0);
  dim1->set_padding_high(0);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  // Batch 10 and output-feature 12 come from the inputs; spatial dims are
  // computed above.
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
                               inferred_shape));
}
446
// Convolution shape inference with window (kernel) dilation: a window of
// size k dilated by d spans an effective extent of (k - 1) * d + 1.
TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  auto dim0 = window.add_dimensions();
  // x0: window 3 dilated by 6 -> effective 13; (103 - 13) / 3 + 1 = 31.
  dim0->set_size(3);
  dim0->set_stride(3);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim0->set_window_dilation(6);
  dim0->set_base_dilation(1);

  auto dim1 = window.add_dimensions();
  // x1: input 4, pad 2+1 -> 7; window 2 dilated by 2 -> effective 3;
  // (7 - 3) / 1 + 1 = 5.
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim1->set_window_dilation(2);
  dim1->set_base_dilation(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
                               inferred_shape));
}
492
// Convolution shape inference with base (input/LHS) dilation: an input of
// size n dilated by d spans an effective extent of (n - 1) * d + 1.
TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  auto dim0 = window.add_dimensions();
  // x0: input 3 base-dilated by 6 -> effective 13; (13 - 4) / 3 + 1 = 4.
  dim0->set_size(4);
  dim0->set_stride(3);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(6);

  auto dim1 = window.add_dimensions();
  // x1: input 4 base-dilated by 2 -> effective 7; pad 2+1 -> 10;
  // (10 - 2) / 1 + 1 = 9.
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(2);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
                               inferred_shape));
}
538
// Dimension numbers must map each logical dimension to a distinct physical
// dimension; reusing kernel dimension 0 for both input-feature and a spatial
// dimension is an error.
TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
  // Dimension order for this test: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2});

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(3);
  dnums.set_output_batch_dimension(3);
  dnums.set_input_feature_dimension(2);
  dnums.set_output_feature_dimension(2);
  dnums.add_input_spatial_dimensions(0);
  dnums.add_output_spatial_dimensions(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.add_output_spatial_dimensions(1);
  dnums.set_kernel_input_feature_dimension(0);  // duplicated with kernel_x0
  dnums.set_kernel_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);

  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(2);
  dim0->set_stride(1);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim1->set_size(3);
  dim1->set_stride(2);
  dim1->set_padding_low(1);
  dim1->set_padding_high(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("each dimension exactly once"));
}
576
// With batch_group_count=6, the kernel's output feature count (10 here) must
// be a multiple of the batch group count; otherwise inference fails.
TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) {
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(2);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(3);
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {60, 38, 17, 13});
  // Kernel output feature dimension (1) has size 10, not a multiple of 6.
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {38, 10, 4, 4});
  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(4);
  dim1->set_size(4);
  dim0->set_padding_low(0);
  dim0->set_padding_high(2);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim0->set_stride(1);
  dim1->set_stride(1);
  dim0->set_window_dilation(3);
  dim1->set_window_dilation(2);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("to be a multiple of batch group count"));
}
613
// Bundle of arguments for ShapeInference::InferConvolveShape, produced by
// MakeConvolveArgs below and consumed by the element-type Convolve tests.
struct ConvolveArgs {
  Shape lhs_shape;  // convolution input (LHS)
  Shape rhs_shape;  // convolution kernel (RHS)
  ConvolutionDimensionNumbers dnums;
  Window window;
};
620
// Builds a fixed convolution configuration — a [10, 11, 3, 4] LHS of
// 'lhs_type', a [2, 12, 11, 3] RHS of 'rhs_type', and a 3x2 window with
// stride 2x1 and padding 1x0 — so the Convolve* element-type tests below can
// vary only the primitive types.
ConvolveArgs MakeConvolveArgs(PrimitiveType lhs_type, PrimitiveType rhs_type) {
  ConvolveArgs args;
  ConvolutionDimensionNumbers& dnums = args.dnums;

  // Dimension order: batch, feature, x0, x1
  args.lhs_shape = ShapeUtil::MakeShape(lhs_type, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  args.rhs_shape = ShapeUtil::MakeShape(rhs_type, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  // x0: window 3, stride 2, pad 1+1; x1: window 2, stride 1, no padding.
  // Result spatial dims for the shapes above: {2, 3}.
  auto dim0 = args.window.add_dimensions();
  auto dim1 = args.window.add_dimensions();
  dim0->set_size(3);
  dim0->set_stride(2);
  dim0->set_padding_low(1);
  dim0->set_padding_high(1);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(1);
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(0);
  dim1->set_padding_high(0);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(1);
  return args;
}
659
// Mixed BF16 x F16 convolution infers a BF16 result shape.
TEST_F(ShapeInferenceTest, ConvolveWithBF16_F16) {
  ConvolveArgs args = MakeConvolveArgs(BF16, F16);
  auto inferred = ShapeInference::InferConvolveShape(
      args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
      /*batch_group_count=*/1, args.window, args.dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred.ValueOrDie()));
}
671
// Mixed F16 x BF16 convolution also infers a BF16 result shape.
TEST_F(ShapeInferenceTest, ConvolveWithF16_BF16) {
  ConvolveArgs args = MakeConvolveArgs(F16, BF16);
  auto inferred = ShapeInference::InferConvolveShape(
      args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
      /*batch_group_count=*/1, args.window, args.dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred.ValueOrDie()));
}
683
// Mixed S32 x U32 convolution infers an S32 result shape.
TEST_F(ShapeInferenceTest, ConvolveWithS32_U32) {
  ConvolveArgs args = MakeConvolveArgs(S32, U32);
  auto inferred = ShapeInference::InferConvolveShape(
      args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
      /*batch_group_count=*/1, args.window, args.dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred.ValueOrDie()));
}
695
// Mixed U32 x S32 convolution also infers an S32 result shape.
TEST_F(ShapeInferenceTest, ConvolveWithU32_S32) {
  ConvolveArgs args = MakeConvolveArgs(U32, S32);
  auto inferred = ShapeInference::InferConvolveShape(
      args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
      /*batch_group_count=*/1, args.window, args.dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred.ValueOrDie()));
}
707
// Requesting S16 as preferred_element_type for an S8 x S16 convolution
// yields an S16 result.
TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  auto inferred = ShapeInference::InferConvolveShape(
      args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
      /*batch_group_count=*/1, args.window, args.dnums,
      /*preferred_element_type=*/S16);
  ASSERT_IS_OK(inferred.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S16, {10, 12, 2, 3}),
                               inferred.ValueOrDie()));
}
719
// Requesting S32 as preferred_element_type for an S8 x S16 convolution
// yields an S32 result.
TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementTypeSameAsInferredType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  auto inferred = ShapeInference::InferConvolveShape(
      args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
      /*batch_group_count=*/1, args.window, args.dnums,
      /*preferred_element_type=*/S32);
  ASSERT_IS_OK(inferred.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred.ValueOrDie()));
}
731
// For floating-point convolutions a *narrower* preferred element type (BF16
// for F32 inputs) is allowed and determines the result element type.
// Fix: added the missing terminating semicolon after TF_ASSERT_OK_AND_ASSIGN.
TEST_F(ShapeInferenceTest,
       FloatingPointConvolveWithNarrowerPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(F32, F32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/BF16));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}
744
// A floating-point convolution (BF16 inputs) must not request an integral
// preferred element type (S32): inference fails with a category-mismatch
// error.
TEST_F(ShapeInferenceTest,
       FloatingPointConvolveWithInvalidPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(BF16, BF16);
  auto inferred_status =
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S32)
          .status();
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.error_message(),
              HasSubstr("must both be integral or both be floating point"));
}
758
// The converse of the previous test: an integral convolution (S8 x S16) must
// not request a floating-point preferred element type (F32).
TEST_F(ShapeInferenceTest,
       IntegralConvolveWithFloatingPointPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  auto inferred_status =
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/F32)
          .status();
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.error_message(),
              HasSubstr("must both be integral or both be floating point"));
}
772
// A preferred element type whose signedness differs from the operands'
// (U32 requested for signed S8 x S16 inputs) is rejected.
TEST_F(ShapeInferenceTest,
       ConvolveWithPreferredElementTypeWithDifferentSignedness) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  auto inferred_status =
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/U32)
          .status();
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.error_message(),
              HasSubstr("must have the same signedness as the original type"));
}
786
// For integral convolutions, a preferred element type narrower than the
// widest operand type (S8 requested for S8 x S16 inputs) is rejected.
TEST_F(ShapeInferenceTest, ConvolveWithNarrowerPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  auto inferred_status =
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S8)
          .status();
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.error_message(),
              HasSubstr("must not be narrower than the original type"));
}
799
// Helpers shared by the InferFftShape tests below. The string constants are
// substrings expected in the error messages produced by InferFftShape for
// each failure mode.
namespace fft {

static const char* unsupported_rank = "only supports ranks 1-3";
static const char* invalid_rank = "requires input of at least same rank";
static const char* requires_complex_input = "requires complex input type";
static const char* requires_f32_input = "requires F32 or F64 input type";
static const char* dimensions_match = "innermost dimensions match fft_length";
static const char* innermost_dimension_matches =
    "innermost dimension matches fft_length/2+1";

// Asserts that FFT shape inference succeeds for (shape, type, length) and
// produces exactly `expected_shape`.
static void Pass(const Shape& shape, FftType type,
                 absl::Span<const int64> length, const Shape& expected_shape) {
  auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(inferred_shape, expected_shape));
}

// Asserts that FFT shape inference fails for (shape, type, length) with an
// error message containing `message`.
static void Fail(const Shape& shape, FftType type,
                 absl::Span<const int64> length, absl::string_view message) {
  auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr(std::string(message)));
}

}  // namespace fft
827
// FFT accepts fft_length of rank 1-2 on a rank-2 complex input (shape is
// preserved); rank 0, rank > input rank, and rank > 3 are rejected.
TEST_F(ShapeInferenceTest, InferFftShapeTestFftRanks) {
  FftType type = FftType::FFT;
  Shape shape = ShapeUtil::MakeShape(C64, {16, 8});
  fft::Fail(shape, type, {}, fft::unsupported_rank);
  fft::Pass(shape, type, {8}, shape);
  fft::Pass(shape, type, {16, 8}, shape);
  fft::Fail(shape, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(shape, type, {64, 32, 16, 8}, fft::unsupported_rank);
}
837
// FFT requires a complex input type: F32 is rejected, C128 passes through
// unchanged.
TEST_F(ShapeInferenceTest, InferFftShapeTestFftTypes) {
  FftType type = FftType::FFT;
  Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
  Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
  fft::Pass(shape_c128, type, {16, 8}, shape_c128);
}
845
// IFFT has the same rank rules as FFT: lengths of rank 1-2 pass on a rank-2
// input, rank 0 / over-rank lengths fail.
TEST_F(ShapeInferenceTest, InferFftShapeTestIfftRanks) {
  FftType type = FftType::IFFT;
  Shape shape = ShapeUtil::MakeShape(C64, {16, 8});
  fft::Fail(shape, type, {}, fft::unsupported_rank);
  fft::Pass(shape, type, {8}, shape);
  fft::Pass(shape, type, {16, 8}, shape);
  fft::Fail(shape, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(shape, type, {64, 32, 16, 8}, fft::unsupported_rank);
}
855
// IFFT also requires a complex input type: F32 fails, C128 passes through.
TEST_F(ShapeInferenceTest, InferFftShapeTestIfftTypes) {
  FftType type = FftType::IFFT;
  Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
  Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
  fft::Pass(shape_c128, type, {16, 8}, shape_c128);
}
863
// RFFT maps a real F32 input to a complex C64 output whose innermost
// dimension is fft_length/2+1 (8 -> 5 here); rank rules match FFT/IFFT.
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftRanks) {
  FftType type = FftType::RFFT;
  Shape shape_in = ShapeUtil::MakeShape(F32, {16, 8});
  Shape shape_out = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Fail(shape_in, type, {}, fft::unsupported_rank);
  fft::Pass(shape_in, type, {8}, shape_out);
  fft::Pass(shape_in, type, {16, 8}, shape_out);
  fft::Fail(shape_in, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(shape_in, type, {64, 32, 16, 8}, fft::unsupported_rank);
}
874
// RFFT requires the input's innermost dimensions to match fft_length;
// zero-sized innermost dimensions are allowed, and even/odd innermost
// lengths (8 and 9) both map to an output innermost dimension of 5
// (fft_length/2+1).
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftDimensions) {
  FftType type = FftType::RFFT;
  Shape shape = ShapeUtil::MakeShape(F32, {16, 8});
  // Mismatched innermost dimensions in every position fail.
  fft::Fail(shape, type, {4}, fft::dimensions_match);
  fft::Fail(shape, type, {16, 4}, fft::dimensions_match);
  fft::Fail(shape, type, {8, 8}, fft::dimensions_match);
  fft::Fail(shape, type, {8, 16}, fft::dimensions_match);

  // A zero-length innermost dimension is valid and preserved.
  Shape zero_shape_in = ShapeUtil::MakeShape(F32, {16, 0});
  Shape zero_shape_out = ShapeUtil::MakeShape(C64, {16, 0});
  fft::Pass(zero_shape_in, type, {0}, zero_shape_out);
  fft::Pass(zero_shape_in, type, {16, 0}, zero_shape_out);

  // Even (8) and odd (9) innermost lengths both yield floor(n/2)+1 == 5.
  Shape even_shape_in = ShapeUtil::MakeShape(F32, {16, 8});
  Shape odd_shape_in = ShapeUtil::MakeShape(F32, {16, 9});
  Shape shape_out = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Pass(even_shape_in, type, {16, 8}, shape_out);
  fft::Pass(odd_shape_in, type, {16, 9}, shape_out);
}
894
// RFFT requires a real (F32/F64) input: complex inputs of either width are
// rejected.
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftTypes) {
  FftType type = FftType::RFFT;
  Shape shape_c64 = ShapeUtil::MakeShape(C64, {16, 8});
  Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(shape_c64, type, {16, 8}, fft::requires_f32_input);
  fft::Fail(shape_c128, type, {16, 8}, fft::requires_f32_input);
}
902
// IRFFT is RFFT's inverse: complex {16, 5} input with fft_length innermost 8
// yields a real F32 {16, 8} output; rank rules match the other FFT types.
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftRanks) {
  FftType type = FftType::IRFFT;
  Shape shape_in = ShapeUtil::MakeShape(C64, {16, 5});
  Shape shape_out = ShapeUtil::MakeShape(F32, {16, 8});
  fft::Fail(shape_in, type, {}, fft::unsupported_rank);
  fft::Pass(shape_in, type, {8}, shape_out);
  fft::Pass(shape_in, type, {16, 8}, shape_out);
  fft::Fail(shape_in, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(shape_in, type, {64, 32, 16, 8}, fft::unsupported_rank);
}
913
// IRFFT requires the input's innermost dimension to equal fft_length/2+1 and
// the outer dimensions to match fft_length; even and odd fft_lengths both
// invert a {16, 5} complex input.
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftDimensions) {
  FftType type = FftType::IRFFT;
  Shape shape = ShapeUtil::MakeShape(C64, {16, 5});
  // fft_length whose innermost entry doesn't satisfy n/2+1 == 5 fails, as do
  // mismatched outer dimensions.
  fft::Fail(shape, type, {5}, fft::innermost_dimension_matches);
  fft::Fail(shape, type, {16, 5}, fft::innermost_dimension_matches);
  fft::Fail(shape, type, {8, 8}, fft::dimensions_match);
  fft::Fail(shape, type, {8, 9}, fft::dimensions_match);

  // Zero-length innermost dimensions are allowed.
  Shape zero_shape_in = ShapeUtil::MakeShape(C64, {16, 0});
  Shape zero_shape_out = ShapeUtil::MakeShape(F32, {16, 0});
  fft::Pass(zero_shape_in, type, {0}, zero_shape_out);
  fft::Pass(zero_shape_in, type, {16, 0}, zero_shape_out);

  // Even (8) and odd (9) fft_lengths both invert innermost dimension 5.
  Shape even_shape_out = ShapeUtil::MakeShape(F32, {16, 8});
  Shape odd_shape_out = ShapeUtil::MakeShape(F32, {16, 9});
  fft::Pass(shape, type, {16, 8}, even_shape_out);
  fft::Pass(shape, type, {16, 9}, odd_shape_out);
}
932
// IRFFT requires a complex input; a C128 input produces an F64 (not F32)
// real output.
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftTypes) {
  FftType type = FftType::IRFFT;
  Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
  Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 5});
  Shape shape_f64_out = ShapeUtil::MakeShape(F64, {16, 8});
  fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
  fft::Pass(shape_c128, type, {16, 8}, shape_f64_out);
}
941
// Map keeps the operand's dimensions but takes its element type from the
// mapped computation's result: f32[20] mapped through (f32 -> s32) gives
// s32[20].
TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
  Shape arg = ShapeUtil::MakeShape(F32, {20});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_);
  auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
  EXPECT_IS_OK(inferred_status.status());
  Shape expected = ShapeUtil::MakeShape(S32, {20});
  EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie()));
}
950
// Exercises InferMapShape across its success cases (1-3 operands, rank-1 and
// rank-2 operands) and each of its validation errors: no arguments,
// mismatched operand shapes, computation arity, non-scalar result/parameter,
// and parameter/argument element-type mismatch.
TEST_F(ShapeInferenceTest, Map) {
  // Two f32[32] operands through (f32, f32) -> f32 keeps the operand shape.
  auto inferred_status_r1f32 = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  EXPECT_IS_OK(inferred_status_r1f32.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie()));

  // It's OK to provide a single argument, as long as the applied arity matches
  // (this degenerates to a Map).
  auto inferred_status_r1f32_one = ShapeInference::InferMapShape(
      {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0});
  EXPECT_IS_OK(inferred_status_r1f32_one.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie()));

  // Three rank-2 s32 operands also keep their shape.
  auto inferred_status_r2s32 = ShapeInference::InferMapShape(
      {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_},
      ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1});
  EXPECT_IS_OK(inferred_status_r2s32.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie()));

  // Map with zero operands is rejected.
  auto no_args_error = ShapeInference::InferMapShape(
      {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {});
  ASSERT_FALSE(no_args_error.ok());
  ASSERT_THAT(no_args_error.status().error_message(),
              HasSubstr("expects at least one argument"));

  // All operands must share one shape ([32] vs [64] fails).
  auto args_diff_shapes_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_64_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(args_diff_shapes_error.ok());
  ASSERT_THAT(args_diff_shapes_error.status().error_message(),
              HasSubstr("requires all operands to have the same shape"));

  // Computation arity must equal the operand count.
  auto arity_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_),
      {0});
  ASSERT_FALSE(arity_error.ok());
  ASSERT_THAT(arity_error.status().error_message(),
              HasSubstr("function arity must match"));

  // The mapped computation must return a scalar.
  auto output_shape_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0});
  ASSERT_FALSE(output_shape_error.ok());
  ASSERT_THAT(output_shape_error.status().error_message(),
              HasSubstr("result has to be a scalar"));

  // Each computation parameter must be a scalar.
  auto param_shape_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0});
  ASSERT_FALSE(param_shape_error.ok());
  ASSERT_THAT(param_shape_error.status().error_message(),
              HasSubstr("parameter has to be a scalar"));

  // Parameter element types must match the corresponding argument.
  auto param_element_type_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0});
  ASSERT_FALSE(param_element_type_error.ok());
  ASSERT_THAT(param_element_type_error.status().error_message(),
              HasSubstr("parameter type has to match argument"));

  // Single-operand success case with a freshly built shape.
  Shape arg = ShapeUtil::MakeShape(F32, {20});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
  auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
  EXPECT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie()));

  // One operand against a two-parameter computation: arity mismatch.
  auto inferred_status_error1 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("arity must match number of arguments"));

  // Non-scalar computation parameter.
  auto inferred_status_error2 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0});
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("has to be a scalar"));

  // Non-scalar computation result.
  auto inferred_status_error3 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0});
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("has to be a scalar"));

  // Element-type mismatch between the f32 operand and an s32 parameter.
  auto inferred_status_error5 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0});
  ASSERT_FALSE(inferred_status_error5.ok());
  ASSERT_THAT(inferred_status_error5.status().error_message(),
              HasSubstr("parameter type has to match argument"));
}
1044
// Reducing a rank-1 f32[128] over its only dimension yields a scalar.
TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) {
  ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {128}),
                            /*dimensions_to_reduce=*/{0});
}
1049
// Reducing f32[2,3,4] over dimension 0 drops it, leaving f32[3,4].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstDimension) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3, 4}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0});
}
1055
// Reducing f32[2,3,4] over dimension 1 drops it, leaving f32[2,4].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongMiddleDimension) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2, 4}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{1});
}
1061
// Reducing f32[2,3,4] over dimensions {0,1} leaves f32[4].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstTwoDimensions) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {4}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0, 1});
}
1067
// Reducing f32[2,3,4] over dimensions {1,2} leaves f32[2].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongLastTwoDimensions) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{1, 2});
}
1073
// Reducing f32[2,3,4] over non-adjacent dimensions {0,2} leaves f32[3], and
// the result does not depend on the order of dimensions_to_reduce.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstAndLastDimensions) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0, 2});

  // Check that the order of dimensions_to_reduce doesn't matter.
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{2, 0});
}
1084
// Reducing f32[2,3,4] over all dimensions yields a scalar.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
  ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0, 1, 2});
}
1089
// Variadic reduce: two operands (f32[5,3], s32[5,3]) with two init values
// reduced over all dimensions produce a tuple of scalars (f32[], s32[]).
TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  // Reducer takes (accumulators..., elements...) and returns a 2-tuple.
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}),
                               inferred_status.ValueOrDie()));
}
1101
// Variadic reduce-window: two {5,3,1} operands with a {1,2,4} valid-padded
// window produce a tuple of two {5,2,0} shapes (the window is larger than
// the trailing dimension, so it yields zero output positions there).
// Fix: the VLOG of the inferred shape previously called ValueOrDie() before
// EXPECT_IS_OK; on an error status that would CHECK-crash the test binary
// instead of reporting a clean test failure. Check first, then log.
TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> args = {&f32_arg_shape, &s32_arg_shape};
  std::vector<const Shape*> inits = {&f32_, &s32_};
  // Reducer takes (accumulators..., elements...) and returns a 2-tuple.
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  std::vector<int64> window_dimensions = {1, 2, 4};
  std::vector<int64> window_strides = {1, 1, 1};
  std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(AsInt64Slice(f32_arg_shape.dimensions()), window_dimensions,
                  window_strides, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(
      Window window,
      ShapeInference::InferWindowFromDimensions(
          window_dimensions, window_strides, padding_values, {}, {}));
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
  EXPECT_IS_OK(inferred_status.status());
  VLOG(2) << inferred_status.ValueOrDie().ToString() << "\n";
  EXPECT_TRUE(ShapeUtil::Equal(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {5, 2, 0}),
                                 ShapeUtil::MakeShape(S32, {5, 2, 0})}),
      inferred_status.ValueOrDie()));
}
1127
// Variadic reduce with 2 operands + 2 inits requires a 4-parameter reducer;
// a 6-parameter reducer is rejected with an arity error.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_},
                                  ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("must take 4 parameters, but takes 6 parameter(s)"));
}
1140
// The reducer's accumulator parameter types must match its result tuple
// elements: an s32 first parameter against an f32 first result is rejected.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr(
          "parameter shape differs from the result shape: s32[] vs f32[]"));
}
1154
// InferReduceShape requires at least one operand/init pair; an empty
// argument list is rejected.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape({}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("must have at least 2 arguments, has 0"));
}
1163
// Variadic reduce-window whose reducer declares an all-f32 parameter list
// against an s32 operand is rejected with an element-type mismatch.
TEST_F(ReduceShapeInferenceTest, ErrorBadReduceWindowInput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> args = {&f32_arg_shape, &s32_arg_shape};
  std::vector<const Shape*> inits = {&f32_, &s32_};
  // Reducer parameters are all f32, but the second operand is s32.
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, f32_, f32_, f32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  std::vector<int64> window_dimensions = {1, 2, 4};
  std::vector<int64> window_strides = {1, 1, 1};
  std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(AsInt64Slice(f32_arg_shape.dimensions()), window_dimensions,
                  window_strides, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(
      Window window,
      ShapeInference::InferWindowFromDimensions(
          window_dimensions, window_strides, padding_values, {}, {}));
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
  EXPECT_FALSE(inferred_status.status().ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("f32[] vs s32[]"));
}
1186
// A two-operand variadic reduce requires the reducer to return a 2-tuple;
// returning a scalar is rejected.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_);
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but produces a scalar"));
}
1199
// A two-operand variadic reduce whose reducer returns a 3-tuple is rejected:
// the tuple arity must equal the operand count.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but has 3 elements"));
}
1212
// An all-s32 reducer against an f32 init value fails: the accumulator shape
// must match the corresponding init_value shape.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("accumulator shape at index 0 differs from the "
                        "init_value shape: s32[] vs f32[]"));
}
1225
// Reduction dimensions {3,4} on a rank-2 operand are rejected as
// out-of-bounds.
TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  auto inferred_status = ShapeInference::InferReduceShape(
      {&arg_shape, &f32_},
      /*dimensions_to_reduce=*/{3, 4}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("out-of-bounds dimension"));
}
1236
// A single-operand reduce requires a binary reducer; a 3-parameter
// computation is rejected.
TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  auto inferred_status =
      ShapeInference::InferReduceShape({&arg_shape, &f32_},
                                       /*dimensions_to_reduce=*/{0}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("take 2 parameters"));
}
1247
// A reducer returning s32 while the operand/init are f32 is rejected: the
// accumulator parameter must match the reducer's result type.
TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  auto inferred_status =
      ShapeInference::InferReduceShape({&arg_shape, &f32_},
                                       /*dimensions_to_reduce=*/{0}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("0-th parameter shape differs"));
}
1258
// Listing the same reduction dimension twice ({0, 0}) is rejected as a
// duplicate.
TEST_F(ReduceShapeInferenceTest, ReduceWithRepeatedReduceDimension) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  auto inferred_status = ShapeInference::InferReduceShape(
      {&arg_shape, &f32_},
      /*dimensions_to_reduce=*/{0, 0}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Duplicate reduction dimension: 0"));
}
1269
// Slicing rows [32, 64) with unit strides out of a 128x64 matrix yields a
// 32x64 result.
TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
  const Shape input = ShapeUtil::MakeShape(F32, {128, 64});
  auto statusor =
      ShapeInference::InferSliceShape(input, {32, 0}, {64, 64}, {1, 1});
  ASSERT_IS_OK(statusor.status());
  const Shape sliced = statusor.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), sliced));
}
1278
// Slicing an operand with dynamic dimensions: taking a single row of the
// dynamic dimension 0 makes it static (false), while the fully-kept dynamic
// dimension 1 stays dynamic (true).
TEST_F(ShapeInferenceTest, InferSliceWithDynamicDimensions) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}, {true, true});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {33, 64}, {1, 1});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(
      ShapeUtil::MakeShape(F32, {1, 64}, {false, true}), inferred));
}
1288
// Strided slice: [32,64) step 2 gives 16 rows, [0,64) step 4 gives 16 cols.
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {2, 4});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred));
}
1297
// Strided slice where the span is not a multiple of the stride rounds up:
// ceil((20-15)/2) == 3 rows, ceil((13-0)/4) == 4 cols.
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {15, 0}, {20, 13}, {2, 4});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}), inferred));
}
1306
// A zero stride is invalid and yields INVALID_ARGUMENT.
TEST_F(ShapeInferenceTest, InferInvalidStride) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {0, 1});
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
            inferred_status.status().code());
}
1315
// Slice limits beyond the operand's bounds (129 > 128) yield
// INVALID_ARGUMENT.
TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {1, 1});
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
            inferred_status.status().code());
}
1324
// Rank-1 slice: elements [2, 4) of a 17-element vector give a length-2
// vector.
TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
  const Shape input = ShapeUtil::MakeShape(F32, {17});
  auto statusor = ShapeInference::InferSliceShape(input, {2}, {4}, {1});
  ASSERT_TRUE(statusor.ok());
  const Shape sliced = statusor.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(sliced, ShapeUtil::MakeShape(F32, {2})));
}
1333
// GetTupleElement on (f32[], s32[]) returns the element shape at each valid
// index.
TEST_F(ShapeInferenceTest, InferConstIndexShape) {
  Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto inferred0_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, 0);
  auto inferred1_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, 1);
  ASSERT_IS_OK(inferred0_status.status());
  ASSERT_IS_OK(inferred1_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred0_status.ValueOrDie()));
  ASSERT_TRUE(ShapeUtil::Equal(s32_, inferred1_status.ValueOrDie()));
}
1345
// GetTupleElement rejects indices outside [0, tuple arity): both -1 and 2
// fail on a 2-element tuple.
TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) {
  Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto inferredNegative_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, -1);
  auto inferred2_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, 2);
  ASSERT_FALSE(inferredNegative_status.ok());
  ASSERT_FALSE(inferred2_status.ok());
  EXPECT_THAT(inferredNegative_status.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
  EXPECT_THAT(inferred2_status.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
}
1359
// Power with an f32[10] base and a scalar f32 exponent produces the vector
// shape (scalar broadcasts).
TEST_F(ShapeInferenceTest, InferPowShape) {
  const Shape ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto statusor = ShapeInference::InferBinaryOpShape(HloOpcode::kPower,
                                                     ten_floats, f32_, {});
  ASSERT_IS_OK(statusor.status());
  ASSERT_TRUE(ShapeUtil::Equal(ten_floats, statusor.ValueOrDie()));
}
1367
TEST_F(ShapeInferenceTest, InferCompareShape) {
  // Comparison produces PRED elements over the (broadcast) operand shape.
  const Shape ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto result = ShapeInference::InferBinaryOpShape(HloOpcode::kCompare,
                                                   ten_floats, f32_, {});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
                               result.ValueOrDie()));
}
1376
TEST_F(ShapeInferenceTest, InferReshapeDegenerateCombine) {
  // [1, <=1]
  //   | reshape
  // [<=1]
  //
  // Either output dimension could be dynamic; inferred_dimension is the
  // tie-breaker (passed as -1 here).
  const Shape operand = ShapeUtil::MakeShape(F32, {1, 1}, {false, true});
  auto result = ShapeInference::InferReshapeShape(operand, {1, 0}, {1},
                                                  /*inferred_dimension=*/-1);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {1}, {true}), result.ValueOrDie());
}
1388
TEST_F(ShapeInferenceTest, InferReshapeSplit) {
  // [<=10]
  //   | reshape
  // [1, 10]
  //
  // Either output dimension could be dynamic; inferred_dimension tie-breaks
  // in favor of dimension 0.
  const Shape operand = ShapeUtil::MakeShape(F32, {10}, {true});
  auto result = ShapeInference::InferReshapeShape(operand, {0}, {1, 10},
                                                  /*inferred_dimension=*/0);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {1, 10}, {true, false}),
            result.ValueOrDie());
}
1401
TEST_F(ShapeInferenceTest, InferReshapeCombine) {
  // [6, <=10]
  //   | reshape
  // [<=60]
  const Shape operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
  auto result = ShapeInference::InferReshapeShape(operand, {1, 0}, {60},
                                                  /*inferred_dimension=*/-11);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {60}, {true}), result.ValueOrDie());
}
1411
TEST_F(ShapeInferenceTest, UnchangedDimension) {
  // [6, <=10]
  //   | reshape
  // [2, 3, <=10]
  //
  // The trailing dimension is not reshaped, so it stays dynamic.
  const Shape operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
  auto result = ShapeInference::InferReshapeShape(operand, {1, 0}, {2, 3, 10},
                                                  /*inferred_dimension=*/-11);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {2, 3, 10}, {false, false, true}),
            result.ValueOrDie());
}
1422
TEST_F(ShapeInferenceTest, InferDynamicBroadcast) {
  // Broadcasting f32[<=15] to f32[15, <=15]: the operand's dynamic bit
  // follows it to output dimension 1.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {15}, {true});
  auto result = ShapeInference::InferBroadcastShape(operand_shape, {15});
  ASSERT_IS_OK(result.status());
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {15, 15}, {false, true}),
            result.ValueOrDie());
}
1434
TEST_F(ShapeInferenceTest, BroadcastScalar) {
  // Broadcasts of scalars and vectors, checked for several element types.
  for (auto element_type : {F32, U32, S8}) {
    const Shape scalar_shape = ShapeUtil::MakeShape(element_type, {});
    const Shape oned_shape = ShapeUtil::MakeShape(element_type, {3});
    const Shape twod_shape = ShapeUtil::MakeShape(element_type, {2, 3});

    // no-op scalar broadcast
    auto result = ShapeInference::InferBroadcastShape(scalar_shape, {});
    ASSERT_IS_OK(result.status());
    ASSERT_TRUE(ShapeUtil::Equal(scalar_shape, result.ValueOrDie()));

    // scalar -> 1d broadcast
    result = ShapeInference::InferBroadcastShape(scalar_shape, {3});
    ASSERT_IS_OK(result.status());
    ASSERT_TRUE(ShapeUtil::Equal(oned_shape, result.ValueOrDie()));

    // no-op 1d broadcast
    result = ShapeInference::InferBroadcastShape(oned_shape, {});
    ASSERT_IS_OK(result.status());
    ASSERT_TRUE(ShapeUtil::Equal(oned_shape, result.ValueOrDie()));

    // scalar -> 2d broadcast
    result = ShapeInference::InferBroadcastShape(scalar_shape, {2, 3});
    ASSERT_IS_OK(result.status());
    ASSERT_TRUE(ShapeUtil::Equal(twod_shape, result.ValueOrDie()));

    // 1d -> 2d broadcast
    result = ShapeInference::InferBroadcastShape(oned_shape, {2});
    ASSERT_IS_OK(result.status());
    ASSERT_TRUE(ShapeUtil::Equal(twod_shape, result.ValueOrDie()));
  }
}
1467
1468 // scalar <dot> vector: ok
TEST_F(ShapeInferenceTest, ScalarDotVector) {
  // No contracting dimensions: scalar lhs, vector result.
  DotDimensionNumbers dot_dnums;
  auto result = ShapeInference::InferDotOpShape(
      f32_, vector_32_, dot_dnums, /*preferred_element_type=*/absl::nullopt);
  EXPECT_TRUE(result.ok());
  EXPECT_EQ(result.ValueOrDie(), vector_32_);
}
1476
// 3D <dot> 2D: ok; the extra lhs dimension carries through to the result.
TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  // Contract dimension 1 of a rank-3 lhs against dimension 0 of a matrix;
  // the remaining lhs dimensions carry through to the result.
  auto result = ShapeInference::InferDotOpShape(
      ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  EXPECT_TRUE(result.ok());
  EXPECT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
                               ShapeUtil::MakeShape(F32, {32, 32, 64})));
}
1489
1490 // vector <dot> vector -> scalar
TEST_F(ShapeInferenceTest, VectorDotVector) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto match = ShapeInference::InferDotOpShape(
      vector_64_, vector_64_, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(match.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, match.ValueOrDie()));
  // Mismatched contracting sizes (64 vs. 32) must fail.
  auto mismatch = ShapeInference::InferDotOpShape(
      vector_64_, vector_32_, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(mismatch.ok());
}
1505
1506 // matrix <dot> vector -> vector
TEST_F(ShapeInferenceTest, MatrixDotVector) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto match = ShapeInference::InferDotOpShape(
      matrix_32_64_, vector_64_, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(match.status());
  ASSERT_TRUE(ShapeUtil::Equal(match.ValueOrDie(), vector_32_));
  // Contracting a size-64 dimension against a length-32 vector must fail.
  auto mismatch = ShapeInference::InferDotOpShape(
      matrix_32_64_, vector_32_, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(mismatch.ok());
}
1521
1522 // vector <dot> matrix -> vector
TEST_F(ShapeInferenceTest, VectorDotMatrix) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto match = ShapeInference::InferDotOpShape(
      vector_32_, matrix_32_64_, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(match.status());
  ASSERT_TRUE(ShapeUtil::Equal(match.ValueOrDie(), vector_64_));
  // A length-64 vector cannot contract against the matrix's size-32 dim 0.
  auto mismatch = ShapeInference::InferDotOpShape(
      vector_64_, matrix_32_64_, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(mismatch.ok());
}
1537
1538 // matrix <dot> matrix -> matrix
TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto inferred_status_match =
      ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status_match.status());
  ASSERT_TRUE(
      ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
      << "inferred: "
      << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
      // Fix: the failure message previously printed matrix_64_48_, but the
      // shape this assertion actually expects is matrix_32_48_.
      << " expected: " << ShapeUtil::HumanString(matrix_32_48_);
  // Contracting dim sizes 64 vs. 32 do not match -> inference must fail.
  auto inferred_status_mismatch =
      ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status_mismatch.ok());
}
1557
1558 // BatchMatMul with two batch dimensions and one contracting dimension.
TEST_F(ShapeInferenceTest, DotGeneral) {
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14});

  // Batch over the two leading dimensions; contract lhs dim 3 with rhs dim 2.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_lhs_batch_dimensions(1);
  dot_dnums.add_lhs_contracting_dimensions(3);

  dot_dnums.add_rhs_batch_dimensions(0);
  dot_dnums.add_rhs_batch_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(2);

  auto result = ShapeInference::InferDotOpShape(
      lhs_shape, rhs_shape, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), output_shape))
      << "inferred: " << ShapeUtil::HumanString(result.ValueOrDie())
      << " expected: " << ShapeUtil::HumanString(output_shape);
}
1583
1584 // BatchMatMul with two contracting dimensions fails.
TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) {
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});

  // Two contracting dimensions on the lhs but only one on the rhs.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_lhs_contracting_dimensions(3);

  dot_dnums.add_rhs_batch_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(1);

  auto result = ShapeInference::InferDotOpShape(
      lhs_shape, rhs_shape, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Must specify the same number of contracting "
                        "dimensions for lhs and rhs."));
}
1605
TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 2, 14});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14});

  // Each operand contracts over two matching dimensions (sizes {3, 2}).
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_lhs_contracting_dimensions(3);

  dot_dnums.add_rhs_batch_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(2);

  auto result = ShapeInference::InferDotOpShape(
      lhs_shape, rhs_shape, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  EXPECT_TRUE(result.ok());
  EXPECT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), output_shape));
}
1626
TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) {
  // The size operand must be a scalar; a length-1 vector is rejected.
  const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape val_shape = ShapeUtil::MakeShape(S32, {1});
  auto result = ShapeInference::InferSetDimensionSizeShape(
      arg_shape, val_shape, /*dimension=*/0);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("value has to be S32 scalar"));
}
1637
TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) {
  // The size operand must be S32; a U32 scalar is rejected.
  const Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape val_shape = ShapeUtil::MakeShape(U32, {});
  auto result = ShapeInference::InferSetDimensionSizeShape(
      arg_shape, val_shape, /*dimension=*/0);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("value has to be S32 scalar"));
}
1648
1649 // BatchMatMul with different batch dimension sizes fails.
TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
  // Batch dimension has size 2 on the lhs but size 3 on the rhs.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_lhs_contracting_dimensions(2);

  dot_dnums.add_rhs_batch_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(1);

  auto result = ShapeInference::InferDotOpShape(
      lhs_shape, rhs_shape, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Batch dimension sizes must match"));
}
1668
1669 // BatchMatMul with different batch dimension numbers passes
TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersPasses) {
  // The batch dimension is dim 0 on the lhs but dim 1 on the rhs; sizes
  // still line up, so inference succeeds.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_lhs_contracting_dimensions(2);

  dot_dnums.add_rhs_batch_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);

  auto result = ShapeInference::InferDotOpShape(
      lhs_shape, rhs_shape, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_TRUE(result.ok());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
                               ShapeUtil::MakeShape(F32, {2, 11, 14})));
}
1688
1689 // BatchMatMul with out-of-range dimension numbers fails.
TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) {
  // Contracting dimension 3 does not exist on the rank-3 lhs.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_lhs_contracting_dimensions(3);

  dot_dnums.add_rhs_batch_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);

  auto result = ShapeInference::InferDotOpShape(
      lhs_shape, rhs_shape, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("A dimension number is out of range"));
}
1708
1709 // BatchMatMul with non-unique dimension numbers fails.
TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) {
  // lhs dimension 0 is listed as both a batch and a contracting dimension.
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_lhs_contracting_dimensions(0);

  dot_dnums.add_rhs_batch_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);

  auto result = ShapeInference::InferDotOpShape(
      lhs_shape, rhs_shape, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("A dimension number is not unique"));
}
1728
TEST_F(ShapeInferenceTest, DotWithIntegralPreferredElementType) {
  // An integral dot may request a wider integral result type.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(
      const Shape result,
      ShapeInference::InferDotOpShape(ShapeUtil::MakeShape(S8, {32, 32}),
                                      ShapeUtil::MakeShape(S16, {32, 32}),
                                      dot_dnums,
                                      /*preferred_element_type=*/S32));
  EXPECT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(S32, {32, 32})));
}
1741
TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeSameAsInferredType) {
  // Requesting the type inference would pick anyway is accepted as-is.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(
      const Shape result,
      ShapeInference::InferDotOpShape(ShapeUtil::MakeShape(BF16, {32, 32}),
                                      ShapeUtil::MakeShape(F32, {32, 32}),
                                      dot_dnums,
                                      /*preferred_element_type=*/F32));
  EXPECT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(F32, {32, 32})));
}
1754
TEST_F(ShapeInferenceTest, FloatingPointDotWithNarrowerPreferredElementType) {
  // A floating-point dot may request a narrower floating-point result.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(
      const Shape result,
      ShapeInference::InferDotOpShape(ShapeUtil::MakeShape(BF16, {32, 32}),
                                      ShapeUtil::MakeShape(F32, {32, 32}),
                                      dot_dnums,
                                      /*preferred_element_type=*/BF16));
  EXPECT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(BF16, {32, 32})));
}
1767
TEST_F(ShapeInferenceTest, FloatingPointDotWithInvalidPreferredElementType) {
  // A floating-point dot may not request an integral result type.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  const auto status =
      ShapeInference::InferDotOpShape(ShapeUtil::MakeShape(BF16, {32, 32}),
                                      ShapeUtil::MakeShape(BF16, {32, 32}),
                                      dot_dnums,
                                      /*preferred_element_type=*/S32)
          .status();
  ASSERT_FALSE(status.ok());
  ASSERT_THAT(status.error_message(),
              HasSubstr("must both be integral or both be floating point"));
}
1781
TEST_F(ShapeInferenceTest, IntegralDotWithFloatingPointPreferredElementType) {
  // An integral dot may not request a floating-point result type.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  const auto status =
      ShapeInference::InferDotOpShape(ShapeUtil::MakeShape(S8, {32, 32}),
                                      ShapeUtil::MakeShape(S16, {32, 32}),
                                      dot_dnums,
                                      /*preferred_element_type=*/F32)
          .status();
  ASSERT_FALSE(status.ok());
  ASSERT_THAT(status.error_message(),
              HasSubstr("must both be integral or both be floating point"));
}
1795
TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeWithDifferentSignedness) {
  // A signed-integer dot may not request an unsigned result type.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  const auto status =
      ShapeInference::InferDotOpShape(ShapeUtil::MakeShape(S8, {32, 32}),
                                      ShapeUtil::MakeShape(S16, {32, 32}),
                                      dot_dnums,
                                      /*preferred_element_type=*/U32)
          .status();
  ASSERT_FALSE(status.ok());
  ASSERT_THAT(status.error_message(),
              HasSubstr("must have the same signedness as the original type"));
}
1809
TEST_F(ShapeInferenceTest, DotWithNarrowerPreferredElementType) {
  // The preferred type may not be narrower than the widest operand type.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  const auto status =
      ShapeInference::InferDotOpShape(ShapeUtil::MakeShape(S8, {32, 32}),
                                      ShapeUtil::MakeShape(S16, {32, 32}),
                                      dot_dnums,
                                      /*preferred_element_type=*/S8)
          .status();
  ASSERT_FALSE(status.ok());
  ASSERT_THAT(status.error_message(),
              HasSubstr("must not be narrower than the original type"));
}
1823
TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
  // Broadcasting a vector against a matrix for a binary add, pinning the
  // vector to each matrix dimension in turn.
  const Shape mat = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
  const Shape vec16 = ShapeUtil::MakeShape(F32, {16});

  // vec8 lines up with the size-8 dimension 1...
  auto match =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1});
  ASSERT_IS_OK(match.status());
  ASSERT_TRUE(ShapeUtil::Equal(match.ValueOrDie(), mat));

  // ...but not with the size-16 dimension 0.
  auto mismatch =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0});
  ASSERT_FALSE(mismatch.ok());

  // vec16 lines up with dimension 0...
  match = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0});
  ASSERT_IS_OK(match.status());
  ASSERT_TRUE(ShapeUtil::Equal(match.ValueOrDie(), mat));

  // ...but not with dimension 1.
  mismatch =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1});
  ASSERT_FALSE(mismatch.ok());
}
1849
TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) {
  // Broadcasting a matrix against a cube for a binary add; each matrix maps
  // onto a distinct pair of cube dimensions.
  const Shape cube = ShapeUtil::MakeShape(F32, {16, 8, 4});
  const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4});
  const Shape matrix16_4 = ShapeUtil::MakeShape(F32, {16, 4});
  const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8});

  auto result = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube,
                                                   matrix8_4, {1, 2});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), cube));

  result = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube,
                                              matrix16_4, {0, 2});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), cube));

  result = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube,
                                              matrix16_8, {0, 1});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), cube));
}
1872
TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
  // Test various errors with the broadcast argument.
  const Shape tensor = ShapeUtil::MakeShape(F32, {16, 8, 4});
  const Shape tensor8_8_8 = ShapeUtil::MakeShape(F32, {8, 8, 8});
  const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
  const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4});
  const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8});

  // "magical" broadcast rejected: an empty broadcast_dimensions list would
  // require implicit rank alignment, which this API refuses.
  auto inferred_status_error1 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {});
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("Automatic"));

  // broadcast_dimension out of bounds for tensor's rank (3 >= rank 3)
  auto inferred_status_error2 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3});
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              ContainsRegex("Broadcast dimension number .* too large"));

  // broadcast_dimension doesn't match corresponding dimension
  // (vec8 has size 8, tensor dimension 0 has size 16)
  auto inferred_status_error3 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0});
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("Broadcast dimension 0 mismatch"));

  // broadcast_dimensions list too long (3 entries for a rank-2 operand)
  auto inferred_status_error4 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2});
  ASSERT_FALSE(inferred_status_error4.ok());
  ASSERT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("broadcast_dimensions has to match"));

  // there's a dimension above the rank of the tensor
  auto inferred_status_error5 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor, matrix8_4, {3, 0});
  ASSERT_FALSE(inferred_status_error5.ok());
  ASSERT_THAT(inferred_status_error5.status().error_message(),
              ContainsRegex("dimension number .* too large"));

  // broadcasting dimensions don't match in this order
  auto inferred_status_error6 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor, matrix8_4, {2, 1});
  ASSERT_FALSE(inferred_status_error6.ok());
  ASSERT_THAT(inferred_status_error6.status().error_message(),
              HasSubstr("dimension 0 mismatch"));

  // The following two tests make sure that broadcasting dimensions are listed
  // in a proper (strictly increasing) order, even if the lower-rank array
  // matches the higher-rank array in many different ways.
  auto inferred_status_error7 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0});
  ASSERT_FALSE(inferred_status_error7.ok());
  ASSERT_THAT(inferred_status_error7.status().error_message(),
              HasSubstr("dimensions order is wrong"));

  auto inferred_status_error8 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0});
  ASSERT_FALSE(inferred_status_error8.ok());
  ASSERT_THAT(inferred_status_error8.status().error_message(),
              HasSubstr("dimensions order is wrong"));
}
1938
1939 // Tests for the while instruction with proper shapes.
TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) {
  // cond: (s32, f32[32]) -> pred; body: (s32, f32[32]) -> same tuple.
  const Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
  const ProgramShape cond =
      ShapeUtil::MakeProgramShape({result_shape}, pred_);
  const ProgramShape body =
      ShapeUtil::MakeProgramShape({result_shape}, result_shape);
  auto result = ShapeInference::InferWhileShape(cond, body, result_shape);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result_shape, result.ValueOrDie()));
}
1950
1951 // Tests for the while instruction with wrong shapes.
TEST_F(ShapeInferenceTest, WhileWithBadShapes) {
  // Well-formed cond/body pair used as the baseline for each bad variant.
  Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
  ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
  ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);

  // Condition taking two parameters instead of one.
  auto bad_shape_1 = ShapeUtil::MakeProgramShape({s32_, result_shape}, pred_);
  auto inferred_status_error1 =
      ShapeInference::InferWhileShape(bad_shape_1, body, result_shape);
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("Condition must take 1 arguments"));

  // Body taking two parameters instead of one.
  auto bad_shape_2 =
      ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape);
  auto inferred_status_error2 =
      ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape);
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("Body must take 1 arguments"));

  // Condition returning s32 instead of pred.
  auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_);
  auto inferred_status_error3 =
      ShapeInference::InferWhileShape(bad_shape_3, body, result_shape);
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("Condition must return a boolean"));

  // Body returning a shape different from its (and the loop's) parameter.
  auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_);
  auto inferred_status_error4 =
      ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape);
  ASSERT_FALSE(inferred_status_error4.ok());
  ASSERT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("parameter of condition and body"));
}
1986
1987 // Tests for the concatenate instruction with dynamic shapes.
TEST_F(ShapeInferenceTest, ConcatenateWithDynamicShapes) {
  // Concatenating along dimension 0 sums that dimension's sizes; the result
  // is dynamic in every dimension that is dynamic in some operand.
  const Shape dynamic_shape_1 =
      ShapeUtil::MakeShape(F32, {32, 160, 10}, {true, false, false});
  const Shape dynamic_shape_2 =
      ShapeUtil::MakeShape(F32, {32, 160, 10}, {false, true, false});
  auto result = ShapeInference::InferConcatOpShape(
      {&dynamic_shape_1, &dynamic_shape_2}, /*dimension=*/0);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(
      ShapeUtil::MakeShape(F32, {64, 160, 10}, {true, true, false}),
      result.ValueOrDie()));
}
2000
2001 // Tests for the concatenate instruction with proper shapes.
TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) {
  // Two vectors along dimension 0: 32 + 64 = 96.
  auto result = ShapeInference::InferConcatOpShape({&vector_32_, &vector_64_},
                                                   /*dimension=*/0);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), result.ValueOrDie()));

  // Three vectors along dimension 0: 32 + 64 + 32 = 128.
  result = ShapeInference::InferConcatOpShape(
      {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), result.ValueOrDie()));

  // Three matrices along dimension 1: 48 + 64 + 48 = 160.
  result = ShapeInference::InferConcatOpShape(
      {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}),
                               result.ValueOrDie()));
}
2022
2023 // Tests for the concatenate instruction with wrong shapes.
TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
  // Concatenating zero operands is rejected.
  auto inferred_status_error1 =
      ShapeInference::InferConcatOpShape({}, /*dimension=*/0);
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("Concatenate expects at least one argument"));

  // The concatenation dimension must be non-negative...
  auto inferred_status_error2 =
      ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1);
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("dimension out of bounds: -1"));

  // ...and smaller than the operands' rank (1 here).
  auto inferred_status_error3 =
      ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1);
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("dimension out of bounds: 1"));

  // Tuples are not arrays and cannot be concatenated.
  Shape tuple = ShapeUtil::MakeTupleShape({vector_32_});
  auto inferred_status_error4 = ShapeInference::InferConcatOpShape(
      {&vector_32_, &tuple}, /*dimension=*/0);
  ASSERT_FALSE(inferred_status_error4.ok());
  ASSERT_THAT(
      inferred_status_error4.status().error_message(),
      HasSubstr("Expected array argument for operand of concatenation"));

  // Element types must agree across operands (F32 vs. S32 here).
  const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
  auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
      {&vector_32_, &vector_s32}, /*dimension=*/0);
  ASSERT_FALSE(inferred_status_error5.ok());
  ASSERT_THAT(inferred_status_error5.status().error_message(),
              HasSubstr("concatenate arrays with different element types"));

  // All dimensions other than the concatenation dimension must match
  // (dimension 1 is 48 vs. 64 here while concatenating along dimension 0).
  auto inferred_status_error6 = ShapeInference::InferConcatOpShape(
      {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0);
  ASSERT_FALSE(inferred_status_error6.ok());
  ASSERT_THAT(inferred_status_error6.status().error_message(),
              HasSubstr("concatenate arrays that differ in "
                        "dimensions other than the one being "
                        "concatenated"));
}
2066
// Checks pad shape inference: each output dimension is
// input + low + high + (input - 1) * interior, and a padding configuration
// that would produce a negative dimension size is rejected.
TEST_F(ShapeInferenceTest, Pad) {
  Shape operand_shape = ShapeUtil::MakeShape(F32, {10, 25});
  Shape pad_value_shape = ShapeUtil::MakeShape(F32, {});
  // Dimension 0: {low: 0, high: 2, interior: 3} -> 10 + 0 + 2 + 9 * 3 = 39.
  // Dimension 1: {low: 1, high: 5, interior: 0} -> 25 + 1 + 5 = 31.
  PaddingConfig padding_config;
  auto dim0 = padding_config.add_dimensions();
  dim0->set_edge_padding_low(0);
  dim0->set_edge_padding_high(2);
  dim0->set_interior_padding(3);
  auto dim1 = padding_config.add_dimensions();
  dim1->set_edge_padding_low(1);
  dim1->set_edge_padding_high(5);
  dim1->set_interior_padding(0);

  auto padded = ShapeInference::InferPadShape(operand_shape, pad_value_shape,
                                              padding_config);
  ASSERT_IS_OK(padded.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}),
                               padded.ValueOrDie()));

  // Edge padding that shrinks dimension 1 below zero elements must fail.
  dim1->set_edge_padding_low(-20);
  dim1->set_edge_padding_high(-10);
  auto negative_dimension_size = ShapeInference::InferPadShape(
      operand_shape, pad_value_shape, padding_config);
  ASSERT_FALSE(negative_dimension_size.ok());
  ASSERT_THAT(negative_dimension_size.status().error_message(),
              HasSubstr("negative size for dimension 1"));
}
2097
// Reversing every dimension of an array leaves its shape unchanged.
TEST_F(ShapeInferenceTest, Reverse) {
  Shape operand_shape = ShapeUtil::MakeShape(F32, {10, 25});
  auto reversed = ShapeInference::InferReverseShape(operand_shape, {0, 1});
  ASSERT_IS_OK(reversed.status());
  ASSERT_TRUE(ShapeUtil::Equal(operand_shape, reversed.ValueOrDie()));
}
2106
// Reverse must reject out-of-bounds, negative, and duplicated dimension
// indices, as well as non-array (tuple) operands.
TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
  Shape operand_shape = ShapeUtil::MakeShape(F32, {10, 25});

  // Dimension index beyond the operand rank.
  auto oob_dimension = ShapeInference::InferReverseShape(operand_shape, {0, 2});
  ASSERT_FALSE(oob_dimension.ok());
  ASSERT_THAT(oob_dimension.status().error_message(),
              HasSubstr("out-of-bounds"));

  // Negative dimension index.
  auto negative_dimension =
      ShapeInference::InferReverseShape(operand_shape, {0, -1});
  ASSERT_FALSE(negative_dimension.ok());
  ASSERT_THAT(negative_dimension.status().error_message(),
              HasSubstr("out-of-bounds"));

  // The same dimension listed twice.
  auto duplicated_dimension =
      ShapeInference::InferReverseShape(operand_shape, {0, 0});
  ASSERT_FALSE(duplicated_dimension.ok());
  ASSERT_THAT(duplicated_dimension.status().error_message(),
              HasSubstr("duplicated"));

  // Tuples cannot be reversed.
  Shape tuple_shape = ShapeUtil::MakeTupleShape({operand_shape, operand_shape});
  auto tuple_operand = ShapeInference::InferReverseShape(tuple_shape, {0});
  ASSERT_FALSE(tuple_operand.ok());
  ASSERT_THAT(tuple_operand.status().error_message(),
              HasSubstr("Expected array argument"));
}
2135
// Shape inference for kCall: the result is the callee's result shape, and
// the argument list must match the callee's parameters in arity and shape.
TEST_F(ShapeInferenceTest, Call) {
  // Nullary callee.
  auto nullary = ShapeInference::InferCallShape(
      {}, ShapeUtil::MakeProgramShape({}, f32_));
  EXPECT_IS_OK(nullary.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, nullary.ValueOrDie()));

  // Callee taking several parameters of assorted shapes.
  auto multi_arg = ShapeInference::InferCallShape(
      {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_},
      ShapeUtil::MakeProgramShape(
          {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_));
  EXPECT_IS_OK(multi_arg.status());
  EXPECT_TRUE(ShapeUtil::Equal(s32matrix_64_64_, multi_arg.ValueOrDie()));

  // Fewer arguments than parameters.
  auto too_few_args = ShapeInference::InferCallShape(
      {}, ShapeUtil::MakeProgramShape({f32_}, f32_));
  EXPECT_FALSE(too_few_args.ok());
  EXPECT_THAT(too_few_args.status().error_message(),
              HasSubstr("arity must match"));

  // More arguments than parameters.
  auto too_many_args = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({}, f32_));
  EXPECT_FALSE(too_many_args.ok());
  EXPECT_THAT(too_many_args.status().error_message(),
              HasSubstr("arity must match"));

  // Argument shape differs from the corresponding parameter shape.
  auto arg_mismatch = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_));
  EXPECT_FALSE(arg_mismatch.ok());
  EXPECT_THAT(arg_mismatch.status().error_message(),
              HasSubstr("parameter must match argument"));
}
2168
// Transposing a rank-4 array permutes its dimension sizes accordingly.
TEST_F(ShapeInferenceTest, Transpose) {
  Shape operand = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
  auto transposed = ShapeInference::InferTransposeShape(operand, {1, 2, 3, 0});
  EXPECT_IS_OK(transposed);
  EXPECT_TRUE(ShapeUtil::Compatible(transposed.ValueOrDie(),
                                    ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
}
2178
// The identity permutation on a rank-1 array is a no-op transpose.
TEST_F(ShapeInferenceTest, Rank1Transpose) {
  Shape operand = ShapeUtil::MakeShape(F32, {5});
  auto transposed = ShapeInference::InferTransposeShape(operand, {0});
  EXPECT_IS_OK(transposed);
  EXPECT_TRUE(ShapeUtil::Compatible(transposed.ValueOrDie(),
                                    ShapeUtil::MakeShape(F32, {5})));
}
2188
// Shape inference for predicated (two-branch) conditional: the result shape
// is the common result shape of the two branch computations, and each branch
// operand must match its computation's single parameter.
TEST_F(ShapeInferenceTest, ConditionalPred) {
  // Both branches return f32 scalars -> result is an f32 scalar.
  auto inferred_status0 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));

  // Branch parameter shapes may differ as long as each matches its operand.
  auto inferred_status1 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
       ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)},
      {matrix_32_48_, vector_32_});
  EXPECT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));

  // A tuple-shaped branch operand is fine when the branch computation's
  // single parameter is that same tuple shape.
  auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
  auto inferred_status2 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
       ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
      {matrix_32_48_, tuple_f32_v32});
  EXPECT_IS_OK(inferred_status2.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));

  // The branch selector must be a predicate (or index), not f32.
  auto inferred_status_error0 = ShapeInference::InferConditionalShape(
      f32_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error0.ok());
  EXPECT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("must be bool or int32"));

  // Branch computation 0 takes two parameters instead of one.
  auto inferred_status_error1 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
      {ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_});
  EXPECT_FALSE(inferred_status_error1.ok());
  EXPECT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("branch computation 0 must take 1 argument"));

  // Operand 0 (vector_32_) does not match branch 0's parameter (vector_64_).
  auto inferred_status_error2 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_64_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error2.ok());
  EXPECT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("branch operand 0 must match the shape of the only "
                        "parameter of branch computation 0"));

  // Branch computation 1 takes two parameters instead of one.
  auto inferred_status_error3 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
       ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)},
      {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})});
  EXPECT_FALSE(inferred_status_error3.ok());
  EXPECT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("branch computation 1 must take 1 argument"));

  // Operand 1 (vector_64_) does not match branch 1's parameter (vector_32_).
  auto inferred_status_error4 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error4.ok());
  EXPECT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("branch operand 1 must match the shape of the only "
                        "parameter of branch computation 1"));

  // The two branches return different shapes (f32 vs vector_32_).
  auto inferred_status_error5 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error5.ok());
  EXPECT_THAT(inferred_status_error5.status().error_message(),
              HasSubstr("the result of branch 0 computation and branch 1 "
                        "computation must have the same shape"));
}
2272
// Shape inference for indexed (N-branch, S32-selector) conditional: any
// number of branches is allowed, all branch results must agree, and each
// branch operand must match its computation's single parameter.
TEST_F(ShapeInferenceTest, ConditionalIndexed) {
  auto r0s32 = ShapeUtil::MakeShape(S32, {});
  // Three branches, all returning f32 scalars.
  auto inferred_status0 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_, vector_64_});
  EXPECT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));

  // Branch parameter shapes may differ across branches.
  auto inferred_status1 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
       ShapeUtil::MakeProgramShape({vector_32_}, vector_64_),
       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)},
      {matrix_32_48_, vector_32_, matrix_32_48_});
  EXPECT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));

  // A single branch with a tuple-shaped parameter/operand is accepted.
  auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
  auto inferred_status2 = ShapeInference::InferConditionalShape(
      r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
      {tuple_f32_v32});
  EXPECT_IS_OK(inferred_status2.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));

  // A pred selector requires exactly two branches; three are supplied.
  auto inferred_status_error0 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error0.ok());
  EXPECT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("2 == branch_computations.size()"));

  // Branch computation 1 takes two parameters instead of one.
  auto inferred_status_error1 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
       ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
      {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}),
       matrix_32_48_});
  EXPECT_FALSE(inferred_status_error1.ok());
  EXPECT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("branch computation 1 must take 1 argument"));

  // Operand 2 (vector_64_) does not match branch 2's parameter (vector_32_).
  auto inferred_status_error2 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({r0s32}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
      {r0s32, vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error2.ok());
  EXPECT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("branch operand 2 must match the shape of the only "
                        "parameter of branch computation 2"));

  // Branch 3 returns vector_32_ while branch 0 returns f32.
  auto inferred_status_error3 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
      {vector_32_, vector_32_, vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error3.ok());
  EXPECT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("the result of branch 0 computation and branch 3 "
                        "computation must have the same shape"));

  // At least one branch computation is required.
  auto inferred_status_error4 =
      ShapeInference::InferConditionalShape(r0s32, {}, {});
  EXPECT_FALSE(inferred_status_error4.ok());
  EXPECT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("!branch_computations.empty()"));
}
2350
// Slicing past the end of a dimension must be rejected with an error that
// names both the bound violation and the argument shape.
TEST_F(ShapeInferenceTest, BadSlice) {
  auto operand = ShapeUtil::MakeShape(F32, {4});
  StatusOr<Shape> sliced =
      ShapeInference::InferSliceShape(operand, {0}, {5}, {1});
  ASSERT_FALSE(sliced.ok());

  LOG(INFO) << sliced.status();

  EXPECT_THAT(sliced.status().error_message(),
              HasSubstr("less than or equal to dimension size"))
      << sliced.status();
  EXPECT_THAT(sliced.status().error_message(), HasSubstr("argument shape"))
      << sliced.status();
}
2365
// Sort keys and values must have identical dimensions.
TEST_F(ShapeInferenceTest, BadSort) {
  auto keys = ShapeUtil::MakeShape(F32, {4});
  auto values = ShapeUtil::MakeShape(F32, {5});
  StatusOr<Shape> inferred =
      ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values});
  EXPECT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("dimensions must match"))
      << inferred.status();
}
2376
// With multiple value operands, every one of them must match the keys'
// dimensions; a single mismatched operand fails the whole sort.
TEST_F(ShapeInferenceTest, BadSortValuesMismatch) {
  auto keys = ShapeUtil::MakeShape(F32, {4});
  auto values_good = ShapeUtil::MakeShape(F32, {4});
  auto values_bad = ShapeUtil::MakeShape(F32, {5});
  StatusOr<Shape> inferred = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &values_good, &values_bad});
  EXPECT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("dimensions must match"))
      << inferred.status();
}
2388
// Sorting with several value operands yields a tuple containing the key
// shape followed by every value shape, in operand order.
TEST_F(ShapeInferenceTest, SortManyValues) {
  auto keys = ShapeUtil::MakeShape(F32, {4});
  auto values_s32 = ShapeUtil::MakeShape(S32, {4});
  auto values_u32 = ShapeUtil::MakeShape(U32, {4});
  StatusOr<Shape> inferred = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &values_s32, &values_u32});
  EXPECT_IS_OK(inferred);
  EXPECT_TRUE(ShapeUtil::Compatible(
      inferred.ValueOrDie(),
      ShapeUtil::MakeTupleShape({keys, values_s32, values_u32})));
}
2401
// Fixture for gather/scatter shape-inference tests. Provides commonly used
// index and operand shapes; each name's numeric suffix lists the dimension
// sizes. Note that despite the "_4d_" infix, those index tensors are rank 5.
class ScatterGatherShapeInferenceTest : public ShapeInferenceTest {
 protected:
  const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});
  const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5});
  const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32});
  const Shape s64_4d_tensor_10_9_8_7_1_ =
      ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1});
  const Shape s64_4d_tensor_10_9_8_7_5_ =
      ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
  const Shape s64_4d_tensor_5_10_9_7_6_ =
      ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6});
  const Shape s64_4d_tensor_10_9_5_7_6_ =
      ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
  const Shape f32_5d_tensor_50_49_48_47_46_ =
      ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  // Tuple shape used to exercise the "expected array argument" error paths.
  const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
      {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_});
  // Binary (f32, f32) -> f32 computation for scatter update computations.
  const ProgramShape to_apply_ =
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
};
2422
2423 // Shape inference tests for Gather.
2424
// tf.gather-style column gather on a [64,48] matrix: each of the 32 indices
// selects one position along dimension 1, keeping all 64 rows.
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          matrix_64_48_, s64_vector_32_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{0},
              /*collapsed_slice_dims=*/{1},
              /*start_index_map=*/{1},
              /*index_vector_dim=*/1),
          /*slice_sizes=*/{64, 1}));
  EXPECT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {64, 32})))
      << ShapeUtil::HumanString(inferred);
}
2439
// tf.gather-style row gather on a [64,48] matrix: each of the 32 indices
// selects one position along dimension 0, keeping all 48 columns.
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          matrix_64_48_, s64_vector_32_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{1},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0},
              /*index_vector_dim=*/1),
          /*slice_sizes=*/{1, 48}));
  EXPECT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {32, 48})))
      << ShapeUtil::HumanString(inferred);
}
2454
// tf.gather_nd-style gather: a [10,9,8,7,1] index tensor (last dimension is
// the index vector) selects rows of a [64,48] matrix.
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0},
              /*index_vector_dim=*/4),
          /*slice_sizes=*/{1, 48}));
  EXPECT_TRUE(ShapeUtil::Equal(inferred,
                               ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
      << ShapeUtil::HumanString(inferred);
}
2469
// Batched dynamic-slice: each index vector of size 5 selects a full
// [30,29,28,27,26] window of the rank-5 operand, so the result is the index
// batch dimensions followed by the slice sizes.
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4, 5, 6, 7, 8},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/4),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));
  EXPECT_TRUE(ShapeUtil::Equal(
      inferred, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(inferred);
}
2486
// Gathering an entire dynamic dimension preserves its dynamic-ness in the
// output shape.
TEST_F(ScatterGatherShapeInferenceTest, DynamicGatherEntireDimension) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          ShapeUtil::MakeShape(F32, {3, 2, 1}, {false, true, false}),
          ShapeUtil::MakeShape(S64, {}),
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{0, 1},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{1, 2, 1}));
  EXPECT_TRUE(ShapeUtil::Equal(
      inferred, ShapeUtil::MakeShape(F32, {2, 1}, {true, false})))
      << ShapeUtil::HumanString(inferred);
}
2503
// Collapsing the operand's dynamic dimension drops it, so no dynamic
// dimension remains in the output shape.
TEST_F(ScatterGatherShapeInferenceTest, DynamicGatherCollapsedDimension) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          ShapeUtil::MakeShape(F32, {3, 2, 1}, {true, false, false}),
          ShapeUtil::MakeShape(S64, {}),
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{0, 1},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{1, 2, 1}));
  EXPECT_TRUE(ShapeUtil::Equal(
      inferred, ShapeUtil::MakeShape(F32, {2, 1}, {false, false})))
      << ShapeUtil::HumanString(inferred);
}
2520
// A dynamic batch dimension of the start-indices tensor stays dynamic in
// the corresponding output dimension.
TEST_F(ScatterGatherShapeInferenceTest, DynamicIndices) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          ShapeUtil::MakeShape(F32, {3, 2, 2}),
          ShapeUtil::MakeShape(S64, {3, 4, 2}, {false, true, false}),
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{2, 3},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0, 1},
              /*index_vector_dim=*/2),
          /*slice_sizes=*/{1, 2, 2}));
  EXPECT_TRUE(ShapeUtil::Equal(
      inferred,
      ShapeUtil::MakeShape(F32, {3, 4, 2, 2}, {false, true, false, false})))
      << ShapeUtil::HumanString(inferred);
}
2538
// The index vector dimension may sit in the middle of the start-indices
// tensor (dimension 2 here); the remaining dimensions form the batch.
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4, 5, 6, 7, 8},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/2),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));

  EXPECT_TRUE(ShapeUtil::Equal(
      inferred, ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(inferred);
}
2556
// The index vector dimension may also be the leading dimension (dimension 0)
// of the start-indices tensor.
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4, 5, 6, 7, 8},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));

  EXPECT_TRUE(ShapeUtil::Equal(
      inferred, ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(inferred);
}
2574
// With a single index vector and no batch dimensions the gather degenerates
// to a dynamic slice: the output is just the slice sizes.
TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) {
  // This is equivalent to a dynamic slice.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{0, 1, 2, 3, 4},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));

  EXPECT_TRUE(ShapeUtil::Equal(inferred,
                               ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(inferred);
}
2591
// A scalar start-indices "tensor" is allowed; the collapsed leading slice
// dimension drops out of the result.
TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) {
  // The gather indices "tensor" is a scalar S here that's used to slice out
  // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{0, 1, 2, 3},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{1, 30, 29, 28, 27}));

  EXPECT_TRUE(ShapeUtil::Equal(inferred,
                               ShapeUtil::MakeShape(F32, {30, 29, 28, 27})))
      << ShapeUtil::HumanString(inferred);
}
2609
// A tuple-shaped gather operand is rejected.
TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) {
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      tuple_shape_, s64_vector_32_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/1),
      /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Expected array argument for input"))
      << result.status();
}
2624
// A tuple-shaped start-indices operand is rejected.
TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      s64_vector_32_, tuple_shape_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/0),
      /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Expected array argument for gather indices"))
      << result.status();
}
2639
// Start indices must be integral; an F32 index tensor is rejected.
TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      s64_vector_32_, vector_32_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/0),
      /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Gather indices parameter must be an integral tensor"))
      << result.status();
}
2654
// offset_dims must be given in ascending order; {..., 8, 7} is rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingWindowIndices) {
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 8, 7},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("Output window dimensions in gather op must be ascending"))
      << result.status();
}
2671
// offset_dims must not contain duplicates; {..., 7, 7} is rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowIndices) {
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 7},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("Output window dimensions in gather op must not repeat"))
      << result.status();
}
2688
// Offset dimensions far beyond the output rank are rejected; the error
// reports the first offending entry (position 2, value 99).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 99, 100, 101},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Offset dimension 2 in gather op is out of bounds"))
      << result.status();
}
2704
// An offset dimension just one past the valid range is also rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) {
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 9},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Offset dimension 4 in gather op is out of bounds"))
      << result.status();
}
2720
// Every operand dimension must be either an offset dimension or explicitly
// collapsed; listing all five as offsets plus collapsing one is inconsistent.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{4},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("All components of the offset index in a gather op must either "
                "be a offset dimension or explicitly collapsed"))
      << result.status();
}
2738
// collapsed_slice_dims entries must index into the rank-5 operand; 19 is out
// of the valid [0, 5) range.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{0, 1, 2, 3, 19},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Invalid collapsed_slice_dims set in gather op; valid "
                        "range is [0, 5), got: 19"))
      << result.status();
}
2755
// collapsed_slice_dims must not contain duplicates; {..., 3, 3} is rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{0, 1, 2, 3, 3},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Repeated dimensions not allowed in "
                        "collapsed_slice_dims in gather op"))
      << result.status();
}
2772
// Negative test: start_index_map has 4 entries but the index vector
// dimension of start_indices has bound 5.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Gather op has 4 elements in start_index_map and "
                        "the bound of dimension index_vector_dim=4 of "
                        "start_indices is 5. These two numbers must be "
                        "equal."))
      << inferred_status.status();
}
2790
// Negative test: start_index_map maps index component 4 to operand
// dimension 7, which is out of range for a rank-5 operand.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 7},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7"))
      << inferred_status.status();
}
2806
// Negative test: start_index_map maps two index components to the same
// operand dimension (3).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 3},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Repeated dimensions are not allowed in start_index_map"))
      << inferred_status.status();
}
2823
// Negative test: collapsed_slice_dims {2, 1} is not sorted ascending.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{2, 1},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{1, 1, 28, 27, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("collapsed_slice_dims in gather op must be sorted"))
      << inferred_status.status();
}
2839
// Negative test: slice size 300 exceeds the bound (48) of operand
// dimension 3.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowBoundsTooLarge) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7},
      /*collapsed_slice_dims=*/{2},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 1, 300, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Slice size at index 3 in gather op is out of "
                        "range, must be within [0, 48), got 300."))
      << inferred_status.status();
}
2856
// Negative test: only 4 slice sizes are supplied for a rank-5 operand.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Gather op must have one slice size for every input "
                        "dimension"))
      << inferred_status.status();
}
2873
// Negative test: dimension 1 is collapsed but its slice size is 29, and
// only slices of bound 0 or 1 may be collapsed.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7},
      /*collapsed_slice_dims=*/{1},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 26, 20});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Gather op can only collapse slice dims with bound "
                        "1 or 0, but bound is 29 for index 1 at position "
                        "0."))
      << inferred_status.status();
}
2891
// Negative test: index_vector_dim=32 exceeds the rank (4) of the
// start_indices tensor.
TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/32);
  StatusOr<Shape> inferred_status = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Gather index leaf dimension must be within [0, "
                        "rank(start_indices) + 1)"))
      << inferred_status.status();
}
2908
2909 // Shape inference tests for Scatter.
2910
// Positive test: tf.scatter-style update where each update row covers a
// full column; inferred shape must match the operand.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdates) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {64, 32});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_vector_32_, updates_shape, to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{0},
              /*inserted_window_dims=*/{1},
              /*scatter_dims_to_operand_dims=*/{1},
              /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
2924
// Positive test: same as TfScatterWithFullUpdates but scattering along the
// other operand dimension.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdatesV2) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {32, 48});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_vector_32_, updates_shape, to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{1},
              /*inserted_window_dims=*/{0},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
2938
// Positive test: updates smaller than the operand's window dimension are
// allowed; inferred shape still matches the operand.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdates) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {10, 32});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_vector_32_, updates_shape, to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{0},
              /*inserted_window_dims=*/{1},
              /*scatter_dims_to_operand_dims=*/{1},
              /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
2952
// Positive test: partial updates along the other operand dimension.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdatesV2) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {32, 8});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_vector_32_, updates_shape, to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{1},
              /*inserted_window_dims=*/{0},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
2966
// Negative test: update window bound 65 exceeds operand bound 64.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {65, 32});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates_shape, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << inferred_status.status();
}
2983
// Negative test: update window bound 49 exceeds operand bound 48.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {32, 49});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates_shape, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{1},
          /*inserted_window_dims=*/{0},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << inferred_status.status();
}
3000
// Negative test: scatter dimension of updates (31) does not match the
// corresponding scatter-indices bound (32).
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterWithUpdatesNotMatchingIndices) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {64, 31});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates_shape, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << inferred_status.status();
}
3019
// Negative test: same mismatch as above, with the window on the other
// update dimension.
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterWithUpdatesNotMatchingIndicesV2) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {31, 48});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates_shape, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{1},
          /*inserted_window_dims=*/{0},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << inferred_status.status();
}
3038
// Positive test: tf.scatter_nd-style update with full rows; inferred shape
// must equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdates) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, updates_shape, to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4},
              /*inserted_window_dims=*/{0},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
3053
// Positive test: scatter_nd with full columns instead of rows.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdatesV2) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 64});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, updates_shape, to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4},
              /*inserted_window_dims=*/{1},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
3068
// Positive test: scatter_nd with a partial (size-10) update window.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdates) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 10});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, updates_shape, to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4},
              /*inserted_window_dims=*/{0},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
3083
// Positive test: scatter_nd with a partial window on the other operand
// dimension.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 12});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, updates_shape, to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4},
              /*inserted_window_dims=*/{1},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
3098
// Negative test: scatter_nd update window bound 65 exceeds operand
// bound 64.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 65});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, updates_shape, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << inferred_status.status();
}
3115
// Negative test: scatter dimension of updates (9) does not match the
// corresponding scatter-indices bound (10).
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterNdWithUpdatesNotMatchingIndices) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {9, 9, 8, 7, 64});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, updates_shape, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << inferred_status.status();
}
3134
// Positive test: batched dynamic-update-slice expressed as a scatter with
// no inserted window dims; inferred shape equals the 5D operand.
TEST_F(ScatterGatherShapeInferenceTest, TfBatchDynamicUpdateSlice) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
          updates_shape, to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred_shape);
}
3150
// Positive test: the index vector lives in a middle dimension (2) of the
// scatter indices rather than the trailing one.
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDim) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
          updates_shape, to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/2)));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred_shape);
}
3167
// Positive test: the index vector lives in the leading dimension (0) of
// the scatter indices.
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
          updates_shape, to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0)));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred_shape);
}
3184
// Positive test: no scatter (batch) dims in the updates at all — this is
// equivalent to a single dynamic update slice.
TEST_F(ScatterGatherShapeInferenceTest, NoUpdateScatterDims) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, updates_shape,
          to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{0, 1, 2, 3, 4},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0)));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred_shape);
}
3201
// Positive test: a scalar index S selects where the [30,29,28,27] update
// block lands inside the operand.
TEST_F(ScatterGatherShapeInferenceTest, ScalarScatterIndices) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {30, 29, 28, 27});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_scalar_, updates_shape,
          to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{0, 1, 2, 3},
              /*inserted_window_dims=*/{0},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/0)));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred_shape);
}
3219
// Negative test: the scatter operand must be an array, not a tuple.
TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedTensorInput) {
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      tuple_shape_, s64_vector_32_, s64_vector_32_, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Expected array argument for operand"))
      << inferred_status.status();
}
3233
// Negative test: the scatter indices must be an array, not a tuple.
TEST_F(ScatterGatherShapeInferenceTest,
       ScatterWithTupleShapedScatterIndicesInput) {
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      s64_vector_32_, tuple_shape_, s64_vector_32_, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Expected array argument for scatter indices"))
      << inferred_status.status();
}
3248
// Negative test: the updates operand must be an array, not a tuple.
TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) {
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      s64_vector_32_, s64_vector_32_, tuple_shape_, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Expected array argument for updates"))
      << inferred_status.status();
}
3262
// Negative test: scatter indices must be integral; F32 indices are
// rejected.
TEST_F(ScatterGatherShapeInferenceTest, FloatingPointScatterIndicesInput) {
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      s64_vector_32_, vector_32_, s64_vector_32_, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Scatter indices parameter must be an integral tensor"))
      << inferred_status.status();
}
3276
// Negative test: index_vector_dim=10 exceeds the rank (4) of the scatter
// indices tensor.
TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/10));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Scatter index leaf dimension must be within [0, "
                        "rank(scatter_indices) + 1)"))
      << inferred_status.status();
}
3292
// Negative test: the updates tensor has one more dimension than the dim
// numbers allow.
TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdates) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 50});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Updates tensor must be of rank 7; got 8."))
      << inferred_status.status();
}
3307
// Negative test: the update computation must be a binary reduction; a
// unary computation is rejected.
TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdateComputation) {
  const ProgramShape invalid_update_computation =
      ShapeUtil::MakeProgramShape({f32_}, f32_);
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      invalid_update_computation,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Reduction function must take 2 parameters, but takes 1"))
      << inferred_status.status();
}
3326
// Negative test: update_window_dims {4, 5, 6, 8, 7} is not sorted.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6, 8, 7},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("update_window_dims in scatter op must be sorted"))
      << inferred_status.status();
}
3342
// Negative test: update_window_dims lists dimension 7 twice.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedUpdateWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6, 7, 7},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("update_window_dims in scatter op must not repeat"))
      << inferred_status.status();
}
3358
// Negative test: update_window_dims contains 9, outside the rank-9
// updates tensor's dimension range.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6, 7, 9},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Invalid update_window_dims set in scatter op; "
                        "valid range is [0, 9)"))
      << inferred_status.status();
}
3375
// Negative test: inserted_window_dims {2, 1} is not sorted.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{2, 1},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("inserted_window_dims in scatter op must be sorted"))
      << inferred_status.status();
}
3391
// Negative test: inserted_window_dims lists dimension 1 twice.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedInsertedWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 1},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("inserted_window_dims in scatter op must not repeat"))
      << inferred_status.status();
}
3407
// Negative test: inserted_window_dims contains 5, outside the rank-5
// operand's dimension range.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 5},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Invalid inserted_window_dims set in scatter op; "
                        "valid range is [0, 5)"))
      << inferred_status.status();
}
3424
// Negative test: scatter_dims_to_operand_dims has 4 entries but the index
// vector dimension of scatter_indices has bound 5.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims "
                "and the bound of dimension index_vector_dim=4 of "
                "scatter_indices is 5. These two numbers must be equal"))
      << inferred_status.status();
}
3443
// Negative test: scatter_dims_to_operand_dims maps index component 4 to
// operand dimension 10, which is out of range.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  StatusOr<Shape> inferred_status = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Invalid scatter_dims_to_operand_dims mapping; "
                        "domain is [0, 5), got: 4->10"))
      << inferred_status.status();
}
3460
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) {
  // Dimension 2 appears twice in scatter_dims_to_operand_dims; the mapping
  // must be injective, so inference is expected to reject the duplicate.
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr(
          "Repeated dimensions not allowed in scatter_dims_to_operand_dims"))
      << inferred.status();
}
3478
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_InsufficientWindowDims) {
  // update_window_dims plus inserted_window_dims only account for a window of
  // size 4, but the operand has rank 5 — the two must match, so this fails.
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
      ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0, 1, 2, 3},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr(
          "Scatter op has window of size 4; doesn't match operand of rank 5."))
      << inferred.status();
}
3496
3497 } // namespace
3498 } // namespace xla
3499