• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17 
18 #include <string>
19 
20 #include "absl/strings/string_view.h"
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/xla/client/padding.h"
23 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/test.h"
26 #include "tensorflow/compiler/xla/test_helpers.h"
27 #include "tensorflow/compiler/xla/types.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29 
30 namespace xla {
31 namespace {
32 
33 using ::testing::ContainsRegex;
34 using ::testing::HasSubstr;
35 
36 class ShapeInferenceTest : public ::testing::Test {
37  protected:
38   // Some handy scalar shapes.
39   const Shape s32_ = ShapeUtil::MakeShape(S32, {});
40   const Shape f16_ = ShapeUtil::MakeShape(F16, {});
41   const Shape f32_ = ShapeUtil::MakeShape(F32, {});
42   const Shape f64_ = ShapeUtil::MakeShape(F64, {});
43   const Shape pred_ = ShapeUtil::MakeShape(PRED, {});
44 
45   // Some handy vector and matrix shapes of F32 type.
46   // Suffix: vector_length_, matrix_rows_cols_
47   const Shape vector_32_ = ShapeUtil::MakeShape(F32, {32});
48   const Shape vector_64_ = ShapeUtil::MakeShape(F32, {64});
49   const Shape matrix_32_48_ = ShapeUtil::MakeShape(F32, {32, 48});
50   const Shape matrix_32_64_ = ShapeUtil::MakeShape(F32, {32, 64});
51   const Shape matrix_64_48_ = ShapeUtil::MakeShape(F32, {64, 48});
52 
53   // Some handy S32 arrays.
54   const Shape s32matrix_64_64_ = ShapeUtil::MakeShape(S32, {64, 64});
55 };
56 
57 // Subclass for testing InferReduceShape.
58 class ReduceShapeInferenceTest : public ShapeInferenceTest {
59  protected:
60   // Helper that runs reduce shape inference with the input 'arg' and given
61   // dimensions to reduce, and checks the inferred shape is as expected. The
62   // element type here is hard-coded to F32.
ExpectInferredReduceShape(const Shape & expected_inferred_shape,const Shape & arg,absl::Span<const int64> dimensions_to_reduce)63   void ExpectInferredReduceShape(const Shape& expected_inferred_shape,
64                                  const Shape& arg,
65                                  absl::Span<const int64> dimensions_to_reduce) {
66     ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
67     auto inferred_status = ShapeInference::InferReduceShape(
68         {&arg, &f32_}, dimensions_to_reduce, to_apply);
69     EXPECT_IS_OK(inferred_status.status());
70     EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
71                                  inferred_status.ValueOrDie()));
72   }
73 };
74 
75 // Subclass for testing InferSelectAndScatterShape.
76 class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest {
77  protected:
SelectAndScatterShapeInferenceTest()78   SelectAndScatterShapeInferenceTest() {
79     operand_shape_ = ShapeUtil::MakeShape(F32, {8, 16});
80     source_shape_ = ShapeUtil::MakeShape(F32, {4, 8});
81     WindowDimension dim;
82     dim.set_size(2);
83     dim.set_stride(2);
84     dim.set_padding_low(0);
85     dim.set_padding_high(0);
86     dim.set_window_dilation(1);
87     dim.set_base_dilation(1);
88     *window_.add_dimensions() = dim;
89     *window_.add_dimensions() = dim;
90     init_value_shape_ = ShapeUtil::MakeShape(F32, {});
91     select_program_shape_ = ShapeUtil::MakeProgramShape(
92         {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
93     scatter_program_shape_ = ShapeUtil::MakeProgramShape(
94         {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
95   }
96 
97   Shape operand_shape_;
98   Shape source_shape_;
99   Window window_;
100   Shape init_value_shape_;
101   ProgramShape select_program_shape_;
102   ProgramShape scatter_program_shape_;
103 };
104 
// Negating a matrix must produce a result with the operand's shape.
TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  const StatusOr<Shape> inferred =
      ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, operand);
  ASSERT_IS_OK(inferred.status());
  ASSERT_TRUE(ShapeUtil::Equal(operand, inferred.ValueOrDie()));
}
112 
// Select with a scalar predicate and tuple branches must be rejected.
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) {
  const Shape branch = ShapeUtil::MakeTupleShape({s32_, f32_});
  const StatusOr<Shape> inferred = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, branch, branch);
  ASSERT_FALSE(inferred.ok());
  const std::string message = inferred.status().error_message();
  ASSERT_THAT(message, HasSubstr("Expected array argument for select"));
}
121 
// Select with a scalar predicate and matrix branches must be rejected because
// the predicate does not have the branches' shape.
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
  const StatusOr<Shape> inferred = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(inferred.ok());
  const std::string message = inferred.status().error_message();
  ASSERT_THAT(
      message,
      HasSubstr("Operands to select and predicate must be the same shape"));
}
130 
// Select with a PRED array matching the branches' dimensions infers the
// branches' shape.
TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) {
  const Shape pred_matrix = ShapeUtil::MakeShape(PRED, {64, 48});
  const StatusOr<Shape> inferred = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_matrix, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(inferred.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred.ValueOrDie()));
}
138 
// Exercises several invalid argument combinations for select and checks the
// error message reported for each of them.
TEST_F(ShapeInferenceTest, SelectBadShapes) {
  {
    // The two branches have different shapes.
    const StatusOr<Shape> inferred = ShapeInference::InferTernaryOpShape(
        HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_);
    ASSERT_FALSE(inferred.ok());
    ASSERT_THAT(inferred.status().error_message(),
                HasSubstr("Operands to select must be the same shape"));
  }
  {
    // The predicate has a non-PRED element type.
    const StatusOr<Shape> inferred = ShapeInference::InferTernaryOpShape(
        HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_);
    ASSERT_FALSE(inferred.ok());
    ASSERT_THAT(inferred.status().error_message(),
                HasSubstr("pred operand must have PRED"));
  }
  {
    // The predicate is an array whose dimensions differ from the branches'.
    const Shape pred_vector = ShapeUtil::MakeShape(PRED, {64});
    const StatusOr<Shape> inferred = ShapeInference::InferTernaryOpShape(
        HloOpcode::kSelect, pred_vector, matrix_64_48_, matrix_64_48_);
    ASSERT_FALSE(inferred.ok());
    ASSERT_THAT(
        inferred.status().error_message(),
        HasSubstr("Operands to select and predicate must be the same shape"));
  }
  {
    // Tuples have a TUPLE element type and cannot be the pred of a select.
    const Shape pred_tuple = ShapeUtil::MakeTupleShape({pred_, pred_});
    const Shape branch_tuple = ShapeUtil::MakeTupleShape({f32_, f32_});
    const StatusOr<Shape> inferred = ShapeInference::InferTernaryOpShape(
        HloOpcode::kSelect, pred_tuple, branch_tuple, branch_tuple);
    ASSERT_FALSE(inferred.ok());
    ASSERT_THAT(inferred.status().error_message(),
                HasSubstr("Expected array argument for select pred"));
  }
}
169 
// Clamp with three identically shaped matrices infers that matrix shape.
TEST_F(ShapeInferenceTest, ClampAllMatrix) {
  const Shape& matrix = matrix_64_48_;
  const StatusOr<Shape> inferred = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix, matrix, matrix);
  ASSERT_IS_OK(inferred.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix, inferred.ValueOrDie()));
}
176 
// Clamp with three F32 scalars infers an F32 scalar.
TEST_F(ShapeInferenceTest, ClampAllScalar) {
  const Shape& scalar = f32_;
  const StatusOr<Shape> inferred = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, scalar, scalar, scalar);
  ASSERT_IS_OK(inferred.status());
  ASSERT_TRUE(ShapeUtil::Equal(scalar, inferred.ValueOrDie()));
}
183 
// Clamp is rejected when only the first argument is a scalar.
TEST_F(ShapeInferenceTest, ClampMinScalar) {
  const StatusOr<Shape> inferred = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(inferred.ok());
  const std::string message = inferred.status().error_message();
  ASSERT_THAT(message, HasSubstr("Clamp with different shapes"));
}
191 
// Clamp is rejected when only the third argument is a scalar.
TEST_F(ShapeInferenceTest, ClampMaxScalar) {
  const StatusOr<Shape> inferred = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_);
  ASSERT_FALSE(inferred.ok());
  const std::string message = inferred.status().error_message();
  ASSERT_THAT(message, HasSubstr("Clamp with different shapes"));
}
199 
// Clamp is rejected when only the second argument is a scalar.
TEST_F(ShapeInferenceTest, ClampOperandScalar) {
  const StatusOr<Shape> inferred = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_);
  ASSERT_FALSE(inferred.ok());
  const std::string message = inferred.status().error_message();
  ASSERT_THAT(message, HasSubstr("Clamp with different shapes"));
}
207 
// Clamp is rejected when only the first argument is a matrix.
TEST_F(ShapeInferenceTest, ClampMinMatrix) {
  const StatusOr<Shape> inferred = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, f32_);
  ASSERT_FALSE(inferred.ok());
  const std::string message = inferred.status().error_message();
  ASSERT_THAT(message, HasSubstr("Clamp with different shapes"));
}
215 
// Clamp is rejected when only the third argument is a matrix.
TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
  const StatusOr<Shape> inferred = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, f32_, matrix_64_48_);
  ASSERT_FALSE(inferred.ok());
  const std::string message = inferred.status().error_message();
  ASSERT_THAT(message, HasSubstr("Clamp with different shapes"));
}
223 
// Clamp is rejected when only the second argument is a matrix.
TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
  const StatusOr<Shape> inferred = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, f32_);
  ASSERT_FALSE(inferred.ok());
  const std::string message = inferred.status().error_message();
  ASSERT_THAT(message, HasSubstr("Clamp with different shapes"));
}
231 
// Clamp must fail for every combination of mismatched element types or
// mismatched dimensions among its three arguments.
TEST_F(ShapeInferenceTest, ClampBadShapes) {
  // Returns whether clamp shape inference succeeds for the three arguments,
  // passed through in the given order.
  const auto clamp_ok = [](const Shape& first, const Shape& second,
                           const Shape& third) {
    return ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, first,
                                               second, third)
        .ok();
  };

  // Type mismatch
  ASSERT_FALSE(clamp_ok(s32_, f32_, f32_));
  ASSERT_FALSE(clamp_ok(f32_, s32_, f32_));
  ASSERT_FALSE(clamp_ok(f32_, f32_, s32_));

  // Dimension mismatch
  ASSERT_FALSE(clamp_ok(vector_64_, vector_32_, vector_32_));
  ASSERT_FALSE(clamp_ok(vector_32_, vector_64_, vector_32_));
  ASSERT_FALSE(clamp_ok(vector_32_, vector_32_, vector_64_));

  // Dimension mismatch, where one operand is a scalar
  ASSERT_FALSE(clamp_ok(vector_64_, vector_32_, f32_));
  ASSERT_FALSE(clamp_ok(vector_64_, f32_, vector_32_));
  ASSERT_FALSE(clamp_ok(f32_, vector_64_, vector_32_));
}
264 
// Checks shape inference for the binary kComplex op: rejected input types
// first, then the inferred result shapes for scalar, vector and matrix
// operands with and without broadcast dimensions.
TEST_F(ShapeInferenceTest, Complex) {
  // Infers the shape of complex(lhs, rhs) with broadcast dimensions `bcast`.
  auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
                           absl::Span<const int64> bcast) {
    return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
                                              bcast);
  };
  // Inputs must be FP.
  ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
  ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
  // Component types must match.
  ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
  // Only F32->C64 and F64->C128 supported.
  ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok());
  // Validate correct uses.
  Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
  TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {})));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  // Vector with vector. This case used to repeat the (vector, scalar) check
  // above, leaving the same-shape vector combination untested.
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, vector_32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));

  Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64});
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(vector_64_, matrix_32_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, vector_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, matrix_32_64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));

  // F64 components produce C128.
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {})));
}
305 
// The variadic kTuple op infers a tuple shape made of its operand shapes.
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
  const Shape expected = ShapeUtil::MakeTupleShape({s32_, f32_});
  const StatusOr<Shape> inferred =
      ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_});
  ASSERT_IS_OK(inferred.status());
  ASSERT_TRUE(ShapeUtil::Equal(inferred.ValueOrDie(), expected));
}
313 
// Reducing an 8x8 matrix with a 2x2 window and stride 2 in both dimensions
// must infer a 4x4 result.
TEST_F(ShapeInferenceTest, ReduceWindowInHalf) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8});

  // Both window dimensions are identical: size 2, stride 2, no padding and no
  // dilation.
  Window window;
  WindowDimension dim;
  dim.set_size(2);
  dim.set_stride(2);
  dim.set_padding_low(0);
  dim.set_padding_high(0);
  dim.set_window_dilation(1);
  dim.set_base_dilation(1);
  *window.add_dimensions() = dim;
  *window.add_dimensions() = dim;

  Shape init_value_shape = ShapeUtil::MakeShape(F32, {});
  // Reduction computation: (F32, F32) -> F32.
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      matrix_shape, init_value_shape, window, to_apply);

  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), inferred));
}
338 
// With the fixture's valid arguments, select-and-scatter infers the operand
// shape.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) {
  const StatusOr<Shape> inferred = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_IS_OK(inferred.status());
  ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, inferred.ValueOrDie()));
}
347 
// A source shape of {4, 6} instead of the fixture's {4, 8} must be rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) {
  const Shape bad_source = ShapeUtil::MakeShape(F32, {4, 6});
  const StatusOr<Shape> inferred = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, bad_source,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred.ok());
  const std::string message = inferred.status().error_message();
  ASSERT_THAT(message, HasSubstr("Source shape does not match"));
}
357 
// A select function that takes a single parameter must be rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  const ProgramShape bad_select =
      ShapeUtil::MakeProgramShape({f32_scalar}, pred_);
  const StatusOr<Shape> inferred = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_select, window_, source_shape_, init_value_shape_,
      scatter_program_shape_);
  ASSERT_FALSE(inferred.ok());
  const std::string message = inferred.status().error_message();
  ASSERT_THAT(message, HasSubstr("Select function must take 2 parameters"));
}
368 
// A select function that returns F32 instead of PRED must be rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  const ProgramShape bad_select =
      ShapeUtil::MakeProgramShape({f32_scalar, f32_scalar}, f32_);
  const StatusOr<Shape> inferred = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_select, window_, source_shape_, init_value_shape_,
      scatter_program_shape_);
  ASSERT_FALSE(inferred.ok());
  const std::string message = inferred.status().error_message();
  ASSERT_THAT(message, HasSubstr("Select function must have rank-0 PRED"));
}
379 
// A select function whose first parameter is S32 (the operand is F32) must be
// rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
  const Shape first_param = ShapeUtil::MakeShape(S32, {});
  const Shape second_param = ShapeUtil::MakeShape(F32, {});
  const ProgramShape bad_select =
      ShapeUtil::MakeProgramShape({first_param, second_param}, pred_);
  const StatusOr<Shape> inferred = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_select, window_, source_shape_, init_value_shape_,
      scatter_program_shape_);
  ASSERT_FALSE(inferred.ok());
  const std::string message = inferred.status().error_message();
  ASSERT_THAT(message, HasSubstr("Select function's first parameter"));
}
390 
// A select function whose second parameter is U32 (the operand is F32) must
// be rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
  const Shape first_param = ShapeUtil::MakeShape(F32, {});
  const Shape second_param = ShapeUtil::MakeShape(U32, {});
  const ProgramShape bad_select =
      ShapeUtil::MakeProgramShape({first_param, second_param}, pred_);
  const StatusOr<Shape> inferred = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_select, window_, source_shape_, init_value_shape_,
      scatter_program_shape_);
  ASSERT_FALSE(inferred.ok());
  const std::string message = inferred.status().error_message();
  ASSERT_THAT(message, HasSubstr("Select function's second parameter"));
}
401 
// Infers the shape of a strided and padded convolution (no dilation) whose
// kernel uses a different dimension order than the input.
TEST_F(ShapeInferenceTest, Convolve) {
  // Input and output dimension order: batch, feature, x0, x1.
  const Shape input = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  // Kernel dimension order: x1, output feature, input feature, x0.
  const Shape kernel = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});

  ConvolutionDimensionNumbers dim_numbers;
  dim_numbers.set_input_batch_dimension(0);
  dim_numbers.set_input_feature_dimension(1);
  dim_numbers.add_input_spatial_dimensions(2);
  dim_numbers.add_input_spatial_dimensions(3);
  dim_numbers.set_kernel_input_feature_dimension(2);
  dim_numbers.set_kernel_output_feature_dimension(1);
  dim_numbers.add_kernel_spatial_dimensions(3);
  dim_numbers.add_kernel_spatial_dimensions(0);
  dim_numbers.set_output_batch_dimension(0);
  dim_numbers.set_output_feature_dimension(1);
  dim_numbers.add_output_spatial_dimensions(2);
  dim_numbers.add_output_spatial_dimensions(3);

  Window window;
  // First spatial dimension: size 3, stride 2, padding of 1 on both sides.
  WindowDimension* x0 = window.add_dimensions();
  x0->set_size(3);
  x0->set_stride(2);
  x0->set_padding_low(1);
  x0->set_padding_high(1);
  x0->set_window_dilation(1);
  x0->set_base_dilation(1);
  // Second spatial dimension: size 2, stride 1, no padding.
  WindowDimension* x1 = window.add_dimensions();
  x1->set_size(2);
  x1->set_stride(1);
  x1->set_padding_low(0);
  x1->set_padding_high(0);
  x1->set_window_dilation(1);
  x1->set_base_dilation(1);

  const StatusOr<Shape> inferred = ShapeInference::InferConvolveShape(
      input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dim_numbers, /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred.status());
  const Shape expected = ShapeUtil::MakeShape(F32, {10, 12, 2, 3});
  ASSERT_TRUE(ShapeUtil::Equal(expected, inferred.ValueOrDie()));
}
446 
// Infers the shape of a convolution whose window is dilated by 6 and 2 in the
// two spatial dimensions.
TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
  // Input and output dimension order: batch, feature, x0, x1.
  const Shape input = ShapeUtil::MakeShape(F32, {10, 11, 103, 4});
  // Kernel dimension order: x1, output feature, input feature, x0.
  const Shape kernel = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});

  ConvolutionDimensionNumbers dim_numbers;
  dim_numbers.set_input_batch_dimension(0);
  dim_numbers.set_input_feature_dimension(1);
  dim_numbers.add_input_spatial_dimensions(2);
  dim_numbers.add_input_spatial_dimensions(3);
  dim_numbers.set_kernel_input_feature_dimension(2);
  dim_numbers.set_kernel_output_feature_dimension(1);
  dim_numbers.add_kernel_spatial_dimensions(3);
  dim_numbers.add_kernel_spatial_dimensions(0);
  dim_numbers.set_output_batch_dimension(0);
  dim_numbers.set_output_feature_dimension(1);
  dim_numbers.add_output_spatial_dimensions(2);
  dim_numbers.add_output_spatial_dimensions(3);

  Window window;
  // First spatial dimension: size 3, stride 3, no padding, window dilation 6.
  WindowDimension* x0 = window.add_dimensions();
  x0->set_size(3);
  x0->set_stride(3);
  x0->set_padding_low(0);
  x0->set_padding_high(0);
  x0->set_window_dilation(6);
  x0->set_base_dilation(1);
  // Second spatial dimension: size 2, stride 1, padding 2/1, window
  // dilation 2.
  WindowDimension* x1 = window.add_dimensions();
  x1->set_size(2);
  x1->set_stride(1);
  x1->set_padding_low(2);
  x1->set_padding_high(1);
  x1->set_window_dilation(2);
  x1->set_base_dilation(1);

  const StatusOr<Shape> inferred = ShapeInference::InferConvolveShape(
      input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dim_numbers, /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred.status());
  const Shape expected = ShapeUtil::MakeShape(F32, {10, 12, 31, 5});
  ASSERT_TRUE(ShapeUtil::Equal(expected, inferred.ValueOrDie()));
}
492 
// Infers the shape of a convolution whose input (base) is dilated by 6 and 2
// in the two spatial dimensions.
TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
  // Input and output dimension order: batch, feature, x0, x1.
  const Shape input = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  // Kernel dimension order: x1, output feature, input feature, x0.
  const Shape kernel = ShapeUtil::MakeShape(F32, {2, 12, 11, 4});

  ConvolutionDimensionNumbers dim_numbers;
  dim_numbers.set_input_batch_dimension(0);
  dim_numbers.set_input_feature_dimension(1);
  dim_numbers.add_input_spatial_dimensions(2);
  dim_numbers.add_input_spatial_dimensions(3);
  dim_numbers.set_kernel_input_feature_dimension(2);
  dim_numbers.set_kernel_output_feature_dimension(1);
  dim_numbers.add_kernel_spatial_dimensions(3);
  dim_numbers.add_kernel_spatial_dimensions(0);
  dim_numbers.set_output_batch_dimension(0);
  dim_numbers.set_output_feature_dimension(1);
  dim_numbers.add_output_spatial_dimensions(2);
  dim_numbers.add_output_spatial_dimensions(3);

  Window window;
  // First spatial dimension: size 4, stride 3, no padding, base dilation 6.
  WindowDimension* x0 = window.add_dimensions();
  x0->set_size(4);
  x0->set_stride(3);
  x0->set_padding_low(0);
  x0->set_padding_high(0);
  x0->set_window_dilation(1);
  x0->set_base_dilation(6);
  // Second spatial dimension: size 2, stride 1, padding 2/1, base dilation 2.
  WindowDimension* x1 = window.add_dimensions();
  x1->set_size(2);
  x1->set_stride(1);
  x1->set_padding_low(2);
  x1->set_padding_high(1);
  x1->set_window_dilation(1);
  x1->set_base_dilation(2);

  const StatusOr<Shape> inferred = ShapeInference::InferConvolveShape(
      input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dim_numbers, /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred.status());
  const Shape expected = ShapeUtil::MakeShape(F32, {10, 12, 4, 9});
  ASSERT_TRUE(ShapeUtil::Equal(expected, inferred.ValueOrDie()));
}
538 
// Convolution dimension numbers that name the same kernel dimension twice
// must be rejected.
TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
  // Dimension order for this test: batch, feature, x0, x1
  const Shape input = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  const Shape kernel = ShapeUtil::MakeShape(F32, {12, 11, 3, 2});

  ConvolutionDimensionNumbers dim_numbers;
  dim_numbers.set_input_batch_dimension(3);
  dim_numbers.set_input_feature_dimension(2);
  dim_numbers.add_input_spatial_dimensions(0);
  dim_numbers.add_input_spatial_dimensions(1);
  dim_numbers.set_output_batch_dimension(3);
  dim_numbers.set_output_feature_dimension(2);
  dim_numbers.add_output_spatial_dimensions(0);
  dim_numbers.add_output_spatial_dimensions(1);
  // Kernel dimension 0 is used both as the input feature dimension and as the
  // first spatial dimension (duplicated with kernel_x0).
  dim_numbers.set_kernel_input_feature_dimension(0);
  dim_numbers.set_kernel_output_feature_dimension(3);
  dim_numbers.add_kernel_spatial_dimensions(0);
  dim_numbers.add_kernel_spatial_dimensions(1);

  Window window;
  WindowDimension* x0 = window.add_dimensions();
  x0->set_size(2);
  x0->set_stride(1);
  x0->set_padding_low(0);
  x0->set_padding_high(0);
  WindowDimension* x1 = window.add_dimensions();
  x1->set_size(3);
  x1->set_stride(2);
  x1->set_padding_low(1);
  x1->set_padding_high(1);

  const StatusOr<Shape> inferred = ShapeInference::InferConvolveShape(
      input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dim_numbers, /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred.ok());
  const std::string message = inferred.status().error_message();
  ASSERT_THAT(message, HasSubstr("each dimension exactly once"));
}
576 
// A batch group count of 6 combined with a kernel output feature dimension of
// size 10 must be rejected.
TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) {
  const Shape input = ShapeUtil::MakeShape(F32, {60, 38, 17, 13});
  const Shape kernel = ShapeUtil::MakeShape(F32, {38, 10, 4, 4});

  ConvolutionDimensionNumbers dim_numbers;
  dim_numbers.set_input_batch_dimension(0);
  dim_numbers.set_input_feature_dimension(1);
  dim_numbers.add_input_spatial_dimensions(2);
  dim_numbers.add_input_spatial_dimensions(3);
  dim_numbers.set_kernel_input_feature_dimension(0);
  dim_numbers.set_kernel_output_feature_dimension(1);
  dim_numbers.add_kernel_spatial_dimensions(2);
  dim_numbers.add_kernel_spatial_dimensions(3);
  dim_numbers.set_output_batch_dimension(0);
  dim_numbers.set_output_feature_dimension(1);
  dim_numbers.add_output_spatial_dimensions(2);
  dim_numbers.add_output_spatial_dimensions(3);

  Window window;
  WindowDimension* x0 = window.add_dimensions();
  x0->set_size(4);
  x0->set_padding_low(0);
  x0->set_padding_high(2);
  x0->set_stride(1);
  x0->set_window_dilation(3);
  WindowDimension* x1 = window.add_dimensions();
  x1->set_size(4);
  x1->set_padding_low(2);
  x1->set_padding_high(1);
  x1->set_stride(1);
  x1->set_window_dilation(2);

  const StatusOr<Shape> inferred = ShapeInference::InferConvolveShape(
      input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/6,
      window, dim_numbers, /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred.ok());
  const std::string message = inferred.status().error_message();
  ASSERT_THAT(message, HasSubstr("to be a multiple of batch group count"));
}
613 
// Bundles the arguments that the convolution tests pass to
// ShapeInference::InferConvolveShape.
struct ConvolveArgs {
  Shape lhs_shape;                    // Shape of the left-hand operand.
  Shape rhs_shape;                    // Shape of the right-hand operand.
  ConvolutionDimensionNumbers dnums;  // Dimension roles of lhs, rhs, output.
  Window window;                      // One entry per spatial dimension.
};
620 
MakeConvolveArgs(PrimitiveType lhs_type,PrimitiveType rhs_type)621 ConvolveArgs MakeConvolveArgs(PrimitiveType lhs_type, PrimitiveType rhs_type) {
622   ConvolveArgs args;
623   ConvolutionDimensionNumbers& dnums = args.dnums;
624 
625   // Dimension order: batch, feature, x0, x1
626   args.lhs_shape = ShapeUtil::MakeShape(lhs_type, {10, 11, 3, 4});
627   dnums.set_input_batch_dimension(0);
628   dnums.set_output_batch_dimension(0);
629   dnums.set_input_feature_dimension(1);
630   dnums.set_output_feature_dimension(1);
631   dnums.add_input_spatial_dimensions(2);
632   dnums.add_output_spatial_dimensions(2);
633   dnums.add_input_spatial_dimensions(3);
634   dnums.add_output_spatial_dimensions(3);
635 
636   // Dimension order: x1, batch, feature, x0
637   args.rhs_shape = ShapeUtil::MakeShape(rhs_type, {2, 12, 11, 3});
638   dnums.set_kernel_input_feature_dimension(2);
639   dnums.set_kernel_output_feature_dimension(1);
640   dnums.add_kernel_spatial_dimensions(3);
641   dnums.add_kernel_spatial_dimensions(0);
642 
643   auto dim0 = args.window.add_dimensions();
644   auto dim1 = args.window.add_dimensions();
645   dim0->set_size(3);
646   dim0->set_stride(2);
647   dim0->set_padding_low(1);
648   dim0->set_padding_high(1);
649   dim0->set_window_dilation(1);
650   dim0->set_base_dilation(1);
651   dim1->set_size(2);
652   dim1->set_stride(1);
653   dim1->set_padding_low(0);
654   dim1->set_padding_high(0);
655   dim1->set_window_dilation(1);
656   dim1->set_base_dilation(1);
657   return args;
658 }
659 
// A convolution of a BF16 lhs with an F16 rhs infers a BF16 result.
TEST_F(ShapeInferenceTest, ConvolveWithBF16_F16) {
  ConvolveArgs args = MakeConvolveArgs(BF16, F16);
  // Terminate the macro invocation explicitly, as its other uses in this file
  // do, rather than depending on how the macro expands.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}
671 
// A convolution of an F16 lhs with a BF16 rhs infers a BF16 result.
TEST_F(ShapeInferenceTest, ConvolveWithF16_BF16) {
  ConvolveArgs args = MakeConvolveArgs(F16, BF16);
  // Terminate the macro invocation explicitly, as its other uses in this file
  // do, rather than depending on how the macro expands.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}
683 
// An S32 lhs convolved with a U32 rhs, with no preferred element type, is
// expected to infer an S32 result of shape [10, 12, 2, 3].
TEST_F(ShapeInferenceTest, ConvolveWithS32_U32) {
  const ConvolveArgs conv = MakeConvolveArgs(S32, U32);
  const Shape expected = ShapeUtil::MakeShape(S32, {10, 12, 2, 3});
  TF_ASSERT_OK_AND_ASSIGN(
      const Shape actual,
      ShapeInference::InferConvolveShape(
          conv.lhs_shape, conv.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, conv.window, conv.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual));
}
695 
// A U32 lhs convolved with an S32 rhs, with no preferred element type, is
// expected to infer an S32 result of shape [10, 12, 2, 3].
TEST_F(ShapeInferenceTest, ConvolveWithU32_S32) {
  const ConvolveArgs conv = MakeConvolveArgs(U32, S32);
  const Shape expected = ShapeUtil::MakeShape(S32, {10, 12, 2, 3});
  TF_ASSERT_OK_AND_ASSIGN(
      const Shape actual,
      ShapeInference::InferConvolveShape(
          conv.lhs_shape, conv.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, conv.window, conv.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual));
}
707 
// With an S8 lhs, an S16 rhs and S16 as the preferred element type, the
// inferred result is expected to be S16 with shape [10, 12, 2, 3].
TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementType) {
  const ConvolveArgs conv = MakeConvolveArgs(S8, S16);
  const Shape expected = ShapeUtil::MakeShape(S16, {10, 12, 2, 3});
  TF_ASSERT_OK_AND_ASSIGN(
      const Shape actual,
      ShapeInference::InferConvolveShape(
          conv.lhs_shape, conv.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, conv.window, conv.dnums,
          /*preferred_element_type=*/S16));
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual));
}
719 
// With an S8 lhs, an S16 rhs and S32 as the preferred element type, the
// inferred result is expected to be S32 with shape [10, 12, 2, 3].
TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementTypeSameAsInferredType) {
  const ConvolveArgs conv = MakeConvolveArgs(S8, S16);
  const Shape expected = ShapeUtil::MakeShape(S32, {10, 12, 2, 3});
  TF_ASSERT_OK_AND_ASSIGN(
      const Shape actual,
      ShapeInference::InferConvolveShape(
          conv.lhs_shape, conv.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, conv.window, conv.dnums,
          /*preferred_element_type=*/S32));
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual));
}
731 
// For an F32 x F32 convolution, a BF16 preferred element type is accepted and
// the inferred result is expected to be BF16 with shape [10, 12, 2, 3].
TEST_F(ShapeInferenceTest,
       FloatingPointConvolveWithNarrowerPreferredElementType) {
  const ConvolveArgs conv = MakeConvolveArgs(F32, F32);
  const Shape expected = ShapeUtil::MakeShape(BF16, {10, 12, 2, 3});
  TF_ASSERT_OK_AND_ASSIGN(
      const Shape actual,
      ShapeInference::InferConvolveShape(
          conv.lhs_shape, conv.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, conv.window, conv.dnums,
          /*preferred_element_type=*/BF16));
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual));
}
744 
// Requesting an S32 result for a BF16 x BF16 convolution must fail with the
// integral/floating-point mismatch error.
TEST_F(ShapeInferenceTest,
       FloatingPointConvolveWithInvalidPreferredElementType) {
  const ConvolveArgs conv = MakeConvolveArgs(BF16, BF16);
  const auto result = ShapeInference::InferConvolveShape(
      conv.lhs_shape, conv.rhs_shape, /*feature_group_count=*/1,
      /*batch_group_count=*/1, conv.window, conv.dnums,
      /*preferred_element_type=*/S32);
  const auto& error = result.status();
  ASSERT_FALSE(error.ok());
  ASSERT_THAT(error.error_message(),
              HasSubstr("must both be integral or both be floating point"));
}
758 
// Requesting an F32 result for an S8 x S16 convolution must fail with the
// integral/floating-point mismatch error.
TEST_F(ShapeInferenceTest,
       IntegralConvolveWithFloatingPointPreferredElementType) {
  const ConvolveArgs conv = MakeConvolveArgs(S8, S16);
  const auto result = ShapeInference::InferConvolveShape(
      conv.lhs_shape, conv.rhs_shape, /*feature_group_count=*/1,
      /*batch_group_count=*/1, conv.window, conv.dnums,
      /*preferred_element_type=*/F32);
  const auto& error = result.status();
  ASSERT_FALSE(error.ok());
  ASSERT_THAT(error.error_message(),
              HasSubstr("must both be integral or both be floating point"));
}
772 
// Requesting an unsigned (U32) result for an S8 x S16 convolution must fail
// with the signedness-mismatch error.
TEST_F(ShapeInferenceTest,
       ConvolveWithPreferredElementTypeWithDifferentSignedness) {
  const ConvolveArgs conv = MakeConvolveArgs(S8, S16);
  const auto result = ShapeInference::InferConvolveShape(
      conv.lhs_shape, conv.rhs_shape, /*feature_group_count=*/1,
      /*batch_group_count=*/1, conv.window, conv.dnums,
      /*preferred_element_type=*/U32);
  const auto& error = result.status();
  ASSERT_FALSE(error.ok());
  ASSERT_THAT(error.error_message(),
              HasSubstr("must have the same signedness as the original type"));
}
786 
// Requesting an S8 result for an S8 x S16 convolution must fail with the
// "narrower than the original type" error.
TEST_F(ShapeInferenceTest, ConvolveWithNarrowerPreferredElementType) {
  const ConvolveArgs conv = MakeConvolveArgs(S8, S16);
  const auto result = ShapeInference::InferConvolveShape(
      conv.lhs_shape, conv.rhs_shape, /*feature_group_count=*/1,
      /*batch_group_count=*/1, conv.window, conv.dnums,
      /*preferred_element_type=*/S8);
  const auto& error = result.status();
  ASSERT_FALSE(error.ok());
  ASSERT_THAT(error.error_message(),
              HasSubstr("must not be narrower than the original type"));
}
799 
800 namespace fft {
801 
802 static const char* unsupported_rank = "only supports ranks 1-3";
803 static const char* invalid_rank = "requires input of at least same rank";
804 static const char* requires_complex_input = "requires complex input type";
805 static const char* requires_f32_input = "requires F32 or F64 input type";
806 static const char* dimensions_match = "innermost dimensions match fft_length";
807 static const char* innermost_dimension_matches =
808     "innermost dimension matches fft_length/2+1";
809 
Pass(const Shape & shape,FftType type,absl::Span<const int64> length,const Shape & expected_shape)810 static void Pass(const Shape& shape, FftType type,
811                  absl::Span<const int64> length, const Shape& expected_shape) {
812   auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
813   ASSERT_IS_OK(inferred_status.status());
814   Shape inferred_shape = inferred_status.ValueOrDie();
815   ASSERT_TRUE(ShapeUtil::Equal(inferred_shape, expected_shape));
816 }
817 
Fail(const Shape & shape,FftType type,absl::Span<const int64> length,absl::string_view message)818 static void Fail(const Shape& shape, FftType type,
819                  absl::Span<const int64> length, absl::string_view message) {
820   auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
821   ASSERT_FALSE(inferred_status.ok());
822   ASSERT_THAT(inferred_status.status().error_message(),
823               HasSubstr(std::string(message)));
824 }
825 
826 }  // namespace fft
827 
// FFT on a C64 [16, 8] operand: fft_length of rank 1 or 2 preserves the
// shape, rank 0 and rank 4 report an unsupported rank, and rank 3 reports
// that the input rank is too small.
TEST_F(ShapeInferenceTest, InferFftShapeTestFftRanks) {
  const Shape c64_16x8 = ShapeUtil::MakeShape(C64, {16, 8});
  const FftType kind = FftType::FFT;
  fft::Fail(c64_16x8, kind, {}, fft::unsupported_rank);
  fft::Pass(c64_16x8, kind, {8}, c64_16x8);
  fft::Pass(c64_16x8, kind, {16, 8}, c64_16x8);
  fft::Fail(c64_16x8, kind, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(c64_16x8, kind, {64, 32, 16, 8}, fft::unsupported_rank);
}
837 
// FFT rejects a real (F32) operand and accepts a C128 operand, preserving
// its shape.
TEST_F(ShapeInferenceTest, InferFftShapeTestFftTypes) {
  const FftType kind = FftType::FFT;
  const Shape real_input = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape complex_input = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(real_input, kind, {16, 8}, fft::requires_complex_input);
  fft::Pass(complex_input, kind, {16, 8}, complex_input);
}
845 
// IFFT on a C64 [16, 8] operand: same rank expectations as the forward FFT
// rank test (ranks 1-2 pass, rank 0 and 4 unsupported, rank 3 invalid).
TEST_F(ShapeInferenceTest, InferFftShapeTestIfftRanks) {
  const Shape c64_16x8 = ShapeUtil::MakeShape(C64, {16, 8});
  const FftType kind = FftType::IFFT;
  fft::Fail(c64_16x8, kind, {}, fft::unsupported_rank);
  fft::Pass(c64_16x8, kind, {8}, c64_16x8);
  fft::Pass(c64_16x8, kind, {16, 8}, c64_16x8);
  fft::Fail(c64_16x8, kind, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(c64_16x8, kind, {64, 32, 16, 8}, fft::unsupported_rank);
}
855 
// IFFT rejects a real (F32) operand and accepts a C128 operand, preserving
// its shape.
TEST_F(ShapeInferenceTest, InferFftShapeTestIfftTypes) {
  const FftType kind = FftType::IFFT;
  const Shape real_input = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape complex_input = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(real_input, kind, {16, 8}, fft::requires_complex_input);
  fft::Pass(complex_input, kind, {16, 8}, complex_input);
}
863 
// RFFT on an F32 [16, 8] operand: fft_length of rank 1 or 2 yields
// C64 [16, 5]; rank 0 and rank 4 are unsupported; rank 3 is invalid.
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftRanks) {
  const FftType kind = FftType::RFFT;
  const Shape real_input = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape complex_output = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Fail(real_input, kind, {}, fft::unsupported_rank);
  fft::Pass(real_input, kind, {8}, complex_output);
  fft::Pass(real_input, kind, {16, 8}, complex_output);
  fft::Fail(real_input, kind, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(real_input, kind, {64, 32, 16, 8}, fft::unsupported_rank);
}
874 
// RFFT dimension checks: fft_length must match the innermost input
// dimensions; zero-sized innermost dimensions are accepted; innermost sizes
// 8 and 9 both produce an innermost output size of 5.
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftDimensions) {
  const FftType kind = FftType::RFFT;

  // Mismatched fft_length values against an F32 [16, 8] operand.
  const Shape mismatched_input = ShapeUtil::MakeShape(F32, {16, 8});
  fft::Fail(mismatched_input, kind, {4}, fft::dimensions_match);
  fft::Fail(mismatched_input, kind, {16, 4}, fft::dimensions_match);
  fft::Fail(mismatched_input, kind, {8, 8}, fft::dimensions_match);
  fft::Fail(mismatched_input, kind, {8, 16}, fft::dimensions_match);

  // Zero-sized innermost dimension.
  const Shape empty_input = ShapeUtil::MakeShape(F32, {16, 0});
  const Shape empty_output = ShapeUtil::MakeShape(C64, {16, 0});
  fft::Pass(empty_input, kind, {0}, empty_output);
  fft::Pass(empty_input, kind, {16, 0}, empty_output);

  // Even and odd innermost sizes map to the same output shape.
  const Shape even_input = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape odd_input = ShapeUtil::MakeShape(F32, {16, 9});
  const Shape complex_output = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Pass(even_input, kind, {16, 8}, complex_output);
  fft::Pass(odd_input, kind, {16, 9}, complex_output);
}
894 
// RFFT rejects complex operands (both C64 and C128) with the
// "requires F32 or F64 input type" error.
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftTypes) {
  const FftType kind = FftType::RFFT;
  const Shape c64_input = ShapeUtil::MakeShape(C64, {16, 8});
  const Shape c128_input = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(c64_input, kind, {16, 8}, fft::requires_f32_input);
  fft::Fail(c128_input, kind, {16, 8}, fft::requires_f32_input);
}
902 
// IRFFT on a C64 [16, 5] operand: fft_length of rank 1 or 2 yields
// F32 [16, 8]; rank 0 and rank 4 are unsupported; rank 3 is invalid.
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftRanks) {
  const FftType kind = FftType::IRFFT;
  const Shape complex_input = ShapeUtil::MakeShape(C64, {16, 5});
  const Shape real_output = ShapeUtil::MakeShape(F32, {16, 8});
  fft::Fail(complex_input, kind, {}, fft::unsupported_rank);
  fft::Pass(complex_input, kind, {8}, real_output);
  fft::Pass(complex_input, kind, {16, 8}, real_output);
  fft::Fail(complex_input, kind, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(complex_input, kind, {64, 32, 16, 8}, fft::unsupported_rank);
}
913 
// IRFFT dimension checks against a C64 [16, 5] operand: the innermost input
// dimension must match fft_length/2+1, outer dimensions must match
// fft_length, zero-sized innermost dimensions are accepted, and innermost
// fft_length values 8 and 9 are both valid for an innermost input size of 5.
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftDimensions) {
  const FftType kind = FftType::IRFFT;
  const Shape complex_input = ShapeUtil::MakeShape(C64, {16, 5});

  // Mismatched fft_length values.
  fft::Fail(complex_input, kind, {5}, fft::innermost_dimension_matches);
  fft::Fail(complex_input, kind, {16, 5}, fft::innermost_dimension_matches);
  fft::Fail(complex_input, kind, {8, 8}, fft::dimensions_match);
  fft::Fail(complex_input, kind, {8, 9}, fft::dimensions_match);

  // Zero-sized innermost dimension.
  const Shape empty_input = ShapeUtil::MakeShape(C64, {16, 0});
  const Shape empty_output = ShapeUtil::MakeShape(F32, {16, 0});
  fft::Pass(empty_input, kind, {0}, empty_output);
  fft::Pass(empty_input, kind, {16, 0}, empty_output);

  // The output's innermost size follows the requested fft_length.
  const Shape even_output = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape odd_output = ShapeUtil::MakeShape(F32, {16, 9});
  fft::Pass(complex_input, kind, {16, 8}, even_output);
  fft::Pass(complex_input, kind, {16, 9}, odd_output);
}
932 
// IRFFT rejects a real (F32) operand and maps a C128 [16, 5] operand to an
// F64 [16, 8] result.
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftTypes) {
  const FftType kind = FftType::IRFFT;
  const Shape real_input = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape c128_input = ShapeUtil::MakeShape(C128, {16, 5});
  const Shape f64_output = ShapeUtil::MakeShape(F64, {16, 8});
  fft::Fail(real_input, kind, {16, 8}, fft::requires_complex_input);
  fft::Pass(c128_input, kind, {16, 8}, f64_output);
}
941 
// Mapping an f32 -> s32 scalar function over an F32 [20] operand is expected
// to produce an S32 [20] result.
TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
  Shape arg = ShapeUtil::MakeShape(F32, {20});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_);
  auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
  // Fatal assertion: ValueOrDie() below must not be reached on an error
  // status, otherwise the whole test binary aborts instead of this test
  // failing.
  ASSERT_IS_OK(inferred_status.status());
  Shape expected = ShapeUtil::MakeShape(S32, {20});
  EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie()));
}
950 
TEST_F(ShapeInferenceTest,Map)951 TEST_F(ShapeInferenceTest, Map) {
952   auto inferred_status_r1f32 = ShapeInference::InferMapShape(
953       {&vector_32_, &vector_32_},
954       ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
955   EXPECT_IS_OK(inferred_status_r1f32.status());
956   EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie()));
957 
958   // It's OK to provide a single argument, as long as the applied arity matches
959   // (this degenerates to a Map).
960   auto inferred_status_r1f32_one = ShapeInference::InferMapShape(
961       {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0});
962   EXPECT_IS_OK(inferred_status_r1f32_one.status());
963   EXPECT_TRUE(
964       ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie()));
965 
966   auto inferred_status_r2s32 = ShapeInference::InferMapShape(
967       {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_},
968       ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1});
969   EXPECT_IS_OK(inferred_status_r2s32.status());
970   EXPECT_TRUE(
971       ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie()));
972 
973   auto no_args_error = ShapeInference::InferMapShape(
974       {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {});
975   ASSERT_FALSE(no_args_error.ok());
976   ASSERT_THAT(no_args_error.status().error_message(),
977               HasSubstr("expects at least one argument"));
978 
979   auto args_diff_shapes_error = ShapeInference::InferMapShape(
980       {&vector_32_, &vector_64_},
981       ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
982   ASSERT_FALSE(args_diff_shapes_error.ok());
983   ASSERT_THAT(args_diff_shapes_error.status().error_message(),
984               HasSubstr("requires all operands to have the same shape"));
985 
986   auto arity_error = ShapeInference::InferMapShape(
987       {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_),
988       {0});
989   ASSERT_FALSE(arity_error.ok());
990   ASSERT_THAT(arity_error.status().error_message(),
991               HasSubstr("function arity must match"));
992 
993   auto output_shape_error = ShapeInference::InferMapShape(
994       {&vector_32_, &vector_32_},
995       ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0});
996   ASSERT_FALSE(output_shape_error.ok());
997   ASSERT_THAT(output_shape_error.status().error_message(),
998               HasSubstr("result has to be a scalar"));
999 
1000   auto param_shape_error = ShapeInference::InferMapShape(
1001       {&vector_32_, &vector_32_},
1002       ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0});
1003   ASSERT_FALSE(param_shape_error.ok());
1004   ASSERT_THAT(param_shape_error.status().error_message(),
1005               HasSubstr("parameter has to be a scalar"));
1006 
1007   auto param_element_type_error = ShapeInference::InferMapShape(
1008       {&vector_32_, &vector_32_},
1009       ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0});
1010   ASSERT_FALSE(param_element_type_error.ok());
1011   ASSERT_THAT(param_element_type_error.status().error_message(),
1012               HasSubstr("parameter type has to match argument"));
1013 
1014   Shape arg = ShapeUtil::MakeShape(F32, {20});
1015   ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
1016   auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
1017   EXPECT_IS_OK(inferred_status.status());
1018   EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie()));
1019 
1020   auto inferred_status_error1 = ShapeInference::InferMapShape(
1021       {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
1022   ASSERT_FALSE(inferred_status_error1.ok());
1023   ASSERT_THAT(inferred_status_error1.status().error_message(),
1024               HasSubstr("arity must match number of arguments"));
1025 
1026   auto inferred_status_error2 = ShapeInference::InferMapShape(
1027       {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0});
1028   ASSERT_FALSE(inferred_status_error2.ok());
1029   ASSERT_THAT(inferred_status_error2.status().error_message(),
1030               HasSubstr("has to be a scalar"));
1031 
1032   auto inferred_status_error3 = ShapeInference::InferMapShape(
1033       {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0});
1034   ASSERT_FALSE(inferred_status_error3.ok());
1035   ASSERT_THAT(inferred_status_error3.status().error_message(),
1036               HasSubstr("has to be a scalar"));
1037 
1038   auto inferred_status_error5 = ShapeInference::InferMapShape(
1039       {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0});
1040   ASSERT_FALSE(inferred_status_error5.ok());
1041   ASSERT_THAT(inferred_status_error5.status().error_message(),
1042               HasSubstr("parameter type has to match argument"));
1043 }
1044 
// Reducing dimension 0 of an F32 [128] vector; the scalar f32_ shape is the
// other shape handed to ExpectInferredReduceShape.
TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) {
  const Shape vector_128 = ShapeUtil::MakeShape(F32, {128});
  ExpectInferredReduceShape(f32_, vector_128, /*dimensions_to_reduce=*/{0});
}
1049 
// Reducing dimension 0 of an F32 [2, 3, 4] cube, paired with an F32 [3, 4]
// shape.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstDimension) {
  const Shape remaining = ShapeUtil::MakeShape(F32, {3, 4});
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(remaining, cube, /*dimensions_to_reduce=*/{0});
}
1055 
// Reducing dimension 1 of an F32 [2, 3, 4] cube, paired with an F32 [2, 4]
// shape.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongMiddleDimension) {
  const Shape remaining = ShapeUtil::MakeShape(F32, {2, 4});
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(remaining, cube, /*dimensions_to_reduce=*/{1});
}
1061 
// Reducing dimensions 0 and 1 of an F32 [2, 3, 4] cube, paired with an
// F32 [4] shape.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstTwoDimensions) {
  const Shape remaining = ShapeUtil::MakeShape(F32, {4});
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(remaining, cube, /*dimensions_to_reduce=*/{0, 1});
}
1067 
// Reducing dimensions 1 and 2 of an F32 [2, 3, 4] cube, paired with an
// F32 [2] shape.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongLastTwoDimensions) {
  const Shape remaining = ShapeUtil::MakeShape(F32, {2});
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(remaining, cube, /*dimensions_to_reduce=*/{1, 2});
}
1073 
// Reducing dimensions 0 and 2 of an F32 [2, 3, 4] cube, paired with an
// F32 [3] shape; the dimensions are supplied in both orders.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstAndLastDimensions) {
  const Shape remaining = ShapeUtil::MakeShape(F32, {3});
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(remaining, cube, /*dimensions_to_reduce=*/{0, 2});

  // Check that the order of dimensions_to_reduce doesn't matter.
  ExpectInferredReduceShape(remaining, cube, /*dimensions_to_reduce=*/{2, 0});
}
1084 
// Reducing all three dimensions of an F32 [2, 3, 4] cube, paired with the
// scalar f32_ shape.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(f32_, cube, /*dimensions_to_reduce=*/{0, 1, 2});
}
1089 
// A variadic reduce over an F32 [5, 3] and an S32 [5, 3] operand across both
// dimensions is expected to infer a tuple of scalars (f32[], s32[]).
TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  // Fatal assertion: ValueOrDie() below must not be reached on an error
  // status.
  ASSERT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}),
                               inferred_status.ValueOrDie()));
}
1101 
// A variadic reduce-window over F32 [5, 3, 1] and S32 [5, 3, 1] operands with
// a {1, 2, 4} window, unit strides and kValid padding is expected to infer a
// tuple of (f32[5, 2, 0], s32[5, 2, 0]).
TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> args = {&f32_arg_shape, &s32_arg_shape};
  std::vector<const Shape*> inits = {&f32_, &s32_};
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  std::vector<int64> window_dimensions = {1, 2, 4};
  std::vector<int64> window_strides = {1, 1, 1};
  std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(AsInt64Slice(f32_arg_shape.dimensions()), window_dimensions,
                  window_strides, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(
      Window window,
      ShapeInference::InferWindowFromDimensions(
          window_dimensions, window_strides, padding_values, {}, {}));
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
  // Check the status (fatally) before the first ValueOrDie(); the log
  // statement below dereferences the result as well.
  ASSERT_IS_OK(inferred_status.status());
  VLOG(2) << inferred_status.ValueOrDie().ToString() << "\n";
  EXPECT_TRUE(ShapeUtil::Equal(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {5, 2, 0}),
                                 ShapeUtil::MakeShape(S32, {5, 2, 0})}),
      inferred_status.ValueOrDie()));
}
1127 
// A reducer declared with six parameters is rejected for a two-operand
// variadic reduce, which needs four.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const Shape reducer_result = ShapeUtil::MakeTupleShape({f32_, s32_});
  const ProgramShape reducer = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_, f32_, s32_}, reducer_result);
  const auto result = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("must take 4 parameters, but takes 6 parameter(s)"));
}
1140 
// A reducer whose first parameter is s32 while its first result element is
// f32 is rejected with a parameter/result shape mismatch.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const Shape reducer_result = ShapeUtil::MakeTupleShape({f32_, s32_});
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({s32_, s32_, f32_, s32_}, reducer_result);
  const auto result = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr(
          "parameter shape differs from the result shape: s32[] vs f32[]"));
}
1154 
// Calling InferReduceShape with an empty operand list is rejected.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) {
  const Shape reducer_result = ShapeUtil::MakeTupleShape({f32_, s32_});
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({s32_, s32_, f32_, s32_}, reducer_result);
  const auto result = ShapeInference::InferReduceShape({}, {0, 1}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("must have at least 2 arguments, has 0"));
}
1163 
// A variadic reduce-window whose reducer takes four f32 parameters, while
// the operands and init values are (f32, s32), is rejected with an
// "f32[] vs s32[]" mismatch.
TEST_F(ReduceShapeInferenceTest, ErrorBadReduceWindowInput) {
  Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3, 1});
  Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> operands = {&f32_operand, &s32_operand};
  std::vector<const Shape*> init_values = {&f32_, &s32_};
  ProgramShape reducer = ShapeUtil::MakeProgramShape(
      {f32_, f32_, f32_, f32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));

  // Build the window from explicit sizes and strides with kValid padding.
  std::vector<int64> window_sizes = {1, 2, 4};
  std::vector<int64> strides = {1, 1, 1};
  std::vector<std::pair<int64, int64>> padding =
      MakePadding(AsInt64Slice(f32_operand.dimensions()), window_sizes,
                  strides, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(Window window,
                          ShapeInference::InferWindowFromDimensions(
                              window_sizes, strides, padding, {}, {}));

  auto result = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(operands), absl::MakeSpan(init_values), window, reducer);
  EXPECT_FALSE(result.status().ok());
  EXPECT_THAT(result.status().error_message(), HasSubstr("f32[] vs s32[]"));
}
1186 
// A reducer returning a scalar is rejected for a two-operand variadic
// reduce, which needs a two-element tuple result.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_);
  const auto result = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but produces a scalar"));
}
1199 
// A reducer returning a three-element tuple is rejected for a two-operand
// variadic reduce.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const Shape reducer_result = ShapeUtil::MakeTupleShape({f32_, s32_, s32_});
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, reducer_result);
  const auto result = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but has 3 elements"));
}
1212 
// A reducer that is s32 throughout is rejected because its first accumulator
// (s32) differs from the first init value (f32).
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const Shape reducer_result = ShapeUtil::MakeTupleShape({s32_, s32_});
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({s32_, s32_, s32_, s32_}, reducer_result);
  const auto result = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("accumulator shape at index 0 differs from the "
                        "init_value shape: s32[] vs f32[]"));
}
1225 
// Reducing dimensions 3 and 4 of a rank-2 F32 [5, 3] operand is rejected as
// out of bounds.
TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
  const Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  const ProgramShape reducer = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
  const auto result = ShapeInference::InferReduceShape(
      {&operand, &f32_},
      /*dimensions_to_reduce=*/{3, 4}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("out-of-bounds dimension"));
}
1236 
// A three-parameter reducer is rejected for a single-operand reduce, which
// needs two parameters.
TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
  const Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);
  const auto result =
      ShapeInference::InferReduceShape({&operand, &f32_},
                                       /*dimensions_to_reduce=*/{0}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(), HasSubstr("take 2 parameters"));
}
1247 
// A reducer taking (f32, f32) but returning s32 is rejected with a
// "0-th parameter shape differs" error.
TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
  const Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  const ProgramShape reducer = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);
  const auto result =
      ShapeInference::InferReduceShape({&operand, &f32_},
                                       /*dimensions_to_reduce=*/{0}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("0-th parameter shape differs"));
}
1258 
// Listing dimension 0 twice in dimensions_to_reduce is rejected as a
// duplicate.
TEST_F(ReduceShapeInferenceTest, ReduceWithRepeatedReduceDimension) {
  const Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  const ProgramShape reducer = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
  const auto result = ShapeInference::InferReduceShape(
      {&operand, &f32_},
      /*dimensions_to_reduce=*/{0, 0}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Duplicate reduction dimension: 0"));
}
1269 
TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
  // Slice [32, 64) x [0, 64) out of a 128x64 matrix with unit strides.
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  auto result =
      ShapeInference::InferSliceShape(operand, {32, 0}, {64, 64}, {1, 1});
  ASSERT_IS_OK(result.status());
  const Shape actual = result.ValueOrDie();
  const Shape expected = ShapeUtil::MakeShape(F32, {32, 64});
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual));
}
1278 
TEST_F(ShapeInferenceTest, InferSliceWithDynamicDimensions) {
  // Both operand dimensions are dynamic. The expected result keeps dimension 1
  // dynamic, while dimension 0 (sliced down to size 1) becomes static.
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64}, {true, true});
  auto result =
      ShapeInference::InferSliceShape(operand, {32, 0}, {33, 64}, {1, 1});
  ASSERT_IS_OK(result.status());
  const Shape actual = result.ValueOrDie();
  const Shape expected = ShapeUtil::MakeShape(F32, {1, 64}, {false, true});
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual));
}
1288 
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
  // Extents 32 and 64 with strides 2 and 4 divide evenly into 16x16.
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  auto result =
      ShapeInference::InferSliceShape(operand, {32, 0}, {64, 64}, {2, 4});
  ASSERT_IS_OK(result.status());
  const Shape actual = result.ValueOrDie();
  const Shape expected = ShapeUtil::MakeShape(F32, {16, 16});
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual));
}
1297 
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
  // Extents 5 and 13 are not multiples of strides 2 and 4; the expected sizes
  // (3 and 4) correspond to rounding the quotient up.
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  auto result =
      ShapeInference::InferSliceShape(operand, {15, 0}, {20, 13}, {2, 4});
  ASSERT_IS_OK(result.status());
  const Shape actual = result.ValueOrDie();
  const Shape expected = ShapeUtil::MakeShape(F32, {3, 4});
  ASSERT_TRUE(ShapeUtil::Equal(expected, actual));
}
1306 
TEST_F(ShapeInferenceTest, InferInvalidStride) {
  // A zero stride must be rejected with INVALID_ARGUMENT.
  //
  // The start/limit indices are deliberately kept within the 128x64 bounds so
  // that the zero stride in dimension 0 is the only invalid input; an
  // out-of-bounds limit would also yield INVALID_ARGUMENT and mask a missing
  // stride check (that case is covered by InferOobSliceShapeRank2).
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {128, 2}, {0, 1});
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
            inferred_status.status().code());
}
1315 
TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
  // Limit 129 in dimension 0 exceeds the dimension size of 128.
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  auto result =
      ShapeInference::InferSliceShape(operand, {127, 0}, {129, 2}, {1, 1});
  ASSERT_FALSE(result.ok());
  const auto code = result.status().code();
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, code);
}
1324 
TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
  // Slice elements [2, 4) out of a 17-element vector with unit stride.
  Shape vector_shape = ShapeUtil::MakeShape(F32, {17});
  auto inferred_status =
      ShapeInference::InferSliceShape(vector_shape, {2}, {4}, {1});
  // ASSERT_IS_OK (rather than ASSERT_TRUE on ok()) so that a failure reports
  // the status message, matching the other slice tests.
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2})));
}
1333 
TEST_F(ShapeInferenceTest, InferConstIndexShape) {
  // Indexing a (f32, s32) tuple at 0 and 1 yields the respective element
  // shapes.
  const Shape tuple = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto element0 = ShapeInference::InferGetTupleElementShape(tuple, 0);
  auto element1 = ShapeInference::InferGetTupleElementShape(tuple, 1);
  ASSERT_IS_OK(element0.status());
  ASSERT_IS_OK(element1.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, element0.ValueOrDie()));
  ASSERT_TRUE(ShapeUtil::Equal(s32_, element1.ValueOrDie()));
}
1345 
TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) {
  // Indices -1 and 2 are both outside a two-element tuple.
  const Shape tuple = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto below_range = ShapeInference::InferGetTupleElementShape(tuple, -1);
  auto above_range = ShapeInference::InferGetTupleElementShape(tuple, 2);
  ASSERT_FALSE(below_range.ok());
  ASSERT_FALSE(above_range.ok());
  EXPECT_THAT(below_range.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
  EXPECT_THAT(above_range.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
}
1359 
TEST_F(ShapeInferenceTest, InferPowShape) {
  // f32[10] raised to a scalar f32 power keeps the f32[10] shape.
  const auto vec10 = ShapeUtil::MakeShape(F32, {10});
  auto result =
      ShapeInference::InferBinaryOpShape(HloOpcode::kPower, vec10, f32_, {});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(vec10, result.ValueOrDie()));
}
1367 
TEST_F(ShapeInferenceTest, InferCompareShape) {
  // Comparing f32[10] against a scalar f32 produces pred[10].
  const auto vec10 = ShapeUtil::MakeShape(F32, {10});
  auto result =
      ShapeInference::InferBinaryOpShape(HloOpcode::kCompare, vec10, f32_, {});
  ASSERT_IS_OK(result.status());
  const Shape expected = ShapeUtil::MakeShape(PRED, {10});
  ASSERT_TRUE(ShapeUtil::Equal(expected, result.ValueOrDie()));
}
1376 
TEST_F(ShapeInferenceTest, InferReshapeDegenerateCombine) {
  // [1, <=1]
  //   | reshape
  // [<=1]
  //
  // The dynamic size-1 input dimension is expected to make the single output
  // dimension dynamic.
  auto operand = ShapeUtil::MakeShape(F32, {1, 1}, {false, true});
  auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {1},
                                                  /*inferred_dimension=*/-1);
  // Check the status first: ValueOrDie() on an error would abort the whole
  // test binary instead of failing just this test.
  ASSERT_IS_OK(status.status());
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {1}, {true}), status.ValueOrDie());
}
1388 
TEST_F(ShapeInferenceTest, InferReshapeSplit) {
  // [<=10]
  //   | reshape
  // [<=1, 10]
  //
  // Both output dimension can be dynamic, use inferred_dimension to tie-break.
  // With inferred_dimension=0 the test expects output dimension 0 to be the
  // dynamic one.
  auto operand = ShapeUtil::MakeShape(F32, {10}, {true});
  auto status = ShapeInference::InferReshapeShape(operand, {0}, {1, 10},
                                                  /*inferred_dimension=*/0);
  // Check the status first: ValueOrDie() on an error would abort the whole
  // test binary instead of failing just this test.
  ASSERT_IS_OK(status.status());
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {1, 10}, {true, false}),
            status.ValueOrDie());
}
1401 
TEST_F(ShapeInferenceTest, InferReshapeCombine) {
  // [6, <=10]
  //   | reshape
  // [<=60]
  auto operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
  // NOTE(review): -11 looks like a typo for -1 (the value used by
  // InferReshapeDegenerateCombine); left unchanged here -- confirm that any
  // negative value means "no inferred dimension" before normalizing it.
  auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {60},
                                                  /*inferred_dimension=*/-11);
  // Check the status first: ValueOrDie() on an error would abort the whole
  // test binary instead of failing just this test.
  ASSERT_IS_OK(status.status());
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {60}, {true}), status.ValueOrDie());
}
1411 
TEST_F(ShapeInferenceTest, UnchangedDimension) {
  // [6, <=10]
  //   | reshape
  // [2, 3, <=10]
  //
  // The static 6 is split into 2x3 while the dynamic dimension of size 10 is
  // carried over unchanged and is expected to stay dynamic.
  auto operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
  // NOTE(review): -11 looks like a typo for -1 (the value used by
  // InferReshapeDegenerateCombine); left unchanged here -- confirm that any
  // negative value means "no inferred dimension" before normalizing it.
  auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {2, 3, 10},
                                                  /*inferred_dimension=*/-11);
  // Check the status first: ValueOrDie() on an error would abort the whole
  // test binary instead of failing just this test.
  ASSERT_IS_OK(status.status());
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {2, 3, 10}, {false, false, true}),
            status.ValueOrDie());
}
1422 
TEST_F(ShapeInferenceTest, InferDynamicBroadcast) {
  // Broadcasting f32[<=15] with a new leading dimension of size 15 is expected
  // to give f32[15,<=15]: the added dimension is static and the operand's
  // dynamic dimension is preserved.
  const auto operand = ShapeUtil::MakeShape(F32, {15}, {true});
  auto result = ShapeInference::InferBroadcastShape(operand, {15});
  ASSERT_IS_OK(result.status());
  const Shape actual = result.ValueOrDie();
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {15, 15}, {false, true}), actual);
}
1434 
TEST_F(ShapeInferenceTest, BroadcastScalar) {
  // Exercises InferBroadcastShape for rank-0 and rank-1 operands across a few
  // element types; the broadcast sizes are prepended to the operand's
  // dimensions.
  for (auto type : {F32, U32, S8}) {
    const Shape rank0 = ShapeUtil::MakeShape(type, {});
    const Shape rank1 = ShapeUtil::MakeShape(type, {3});
    const Shape rank2 = ShapeUtil::MakeShape(type, {2, 3});

    {  // scalar with no broadcast sizes: shape is unchanged
      auto result = ShapeInference::InferBroadcastShape(rank0, {});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(rank0, result.ValueOrDie()));
    }
    {  // scalar -> [3]
      auto result = ShapeInference::InferBroadcastShape(rank0, {3});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(rank1, result.ValueOrDie()));
    }
    {  // [3] with no broadcast sizes: shape is unchanged
      auto result = ShapeInference::InferBroadcastShape(rank1, {});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(rank1, result.ValueOrDie()));
    }
    {  // scalar -> [2, 3]
      auto result = ShapeInference::InferBroadcastShape(rank0, {2, 3});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(rank2, result.ValueOrDie()));
    }
    {  // [3] -> [2, 3]
      auto result = ShapeInference::InferBroadcastShape(rank1, {2});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(rank2, result.ValueOrDie()));
    }
  }
}
1467 
1468 // scalar <dot> vector: ok
TEST_F(ShapeInferenceTest,ScalarDotVector)1469 TEST_F(ShapeInferenceTest, ScalarDotVector) {
1470   DotDimensionNumbers dot_dnums;
1471   auto inferred_status = ShapeInference::InferDotOpShape(
1472       f32_, vector_32_, dot_dnums, /*preferred_element_type=*/absl::nullopt);
1473   EXPECT_TRUE(inferred_status.ok());
1474   EXPECT_EQ(inferred_status.ValueOrDie(), vector_32_);
1475 }
1476 
1477 // 3D <dot> 2D: error
TEST_F(ShapeInferenceTest,DotWithRankHigherThanTwo)1478 TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
1479   DotDimensionNumbers dot_dnums;
1480   dot_dnums.add_lhs_contracting_dimensions(1);
1481   dot_dnums.add_rhs_contracting_dimensions(0);
1482   auto inferred_status = ShapeInference::InferDotOpShape(
1483       ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums,
1484       /*preferred_element_type=*/absl::nullopt);
1485   EXPECT_TRUE(inferred_status.ok());
1486   EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
1487                                ShapeUtil::MakeShape(F32, {32, 32, 64})));
1488 }
1489 
1490 // vector <dot> vector -> scalar
TEST_F(ShapeInferenceTest,VectorDotVector)1491 TEST_F(ShapeInferenceTest, VectorDotVector) {
1492   DotDimensionNumbers dot_dnums;
1493   dot_dnums.add_lhs_contracting_dimensions(0);
1494   dot_dnums.add_rhs_contracting_dimensions(0);
1495   auto inferred_status =
1496       ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums,
1497                                       /*preferred_element_type=*/absl::nullopt);
1498   ASSERT_IS_OK(inferred_status.status());
1499   ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
1500   auto inferred_status_mismatch =
1501       ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums,
1502                                       /*preferred_element_type=*/absl::nullopt);
1503   ASSERT_FALSE(inferred_status_mismatch.ok());
1504 }
1505 
1506 // matrix <dot> vector -> vector
TEST_F(ShapeInferenceTest,MatrixDotVector)1507 TEST_F(ShapeInferenceTest, MatrixDotVector) {
1508   DotDimensionNumbers dot_dnums;
1509   dot_dnums.add_lhs_contracting_dimensions(1);
1510   dot_dnums.add_rhs_contracting_dimensions(0);
1511   auto inferred_status =
1512       ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums,
1513                                       /*preferred_element_type=*/absl::nullopt);
1514   ASSERT_IS_OK(inferred_status.status());
1515   ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_));
1516   auto inferred_status_mismatch =
1517       ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums,
1518                                       /*preferred_element_type=*/absl::nullopt);
1519   ASSERT_FALSE(inferred_status_mismatch.ok());
1520 }
1521 
1522 // vector <dot> matrix -> vector
TEST_F(ShapeInferenceTest,VectorDotMatrix)1523 TEST_F(ShapeInferenceTest, VectorDotMatrix) {
1524   DotDimensionNumbers dot_dnums;
1525   dot_dnums.add_lhs_contracting_dimensions(0);
1526   dot_dnums.add_rhs_contracting_dimensions(0);
1527   auto inferred_status =
1528       ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums,
1529                                       /*preferred_element_type=*/absl::nullopt);
1530   ASSERT_IS_OK(inferred_status.status());
1531   ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_));
1532   auto inferred_status_mismatch =
1533       ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums,
1534                                       /*preferred_element_type=*/absl::nullopt);
1535   ASSERT_FALSE(inferred_status_mismatch.ok());
1536 }
1537 
1538 // matrix <dot> matrix -> matrix
TEST_F(ShapeInferenceTest,MatrixDotMatrix)1539 TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
1540   DotDimensionNumbers dot_dnums;
1541   dot_dnums.add_lhs_contracting_dimensions(1);
1542   dot_dnums.add_rhs_contracting_dimensions(0);
1543   auto inferred_status_match =
1544       ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums,
1545                                       /*preferred_element_type=*/absl::nullopt);
1546   ASSERT_IS_OK(inferred_status_match.status());
1547   ASSERT_TRUE(
1548       ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
1549       << "inferred: "
1550       << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
1551       << " expected: " << ShapeUtil::HumanString(matrix_64_48_);
1552   auto inferred_status_mismatch =
1553       ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums,
1554                                       /*preferred_element_type=*/absl::nullopt);
1555   ASSERT_FALSE(inferred_status_mismatch.ok());
1556 }
1557 
1558 // BatchMatMul with two batch dimensions and one contracting dimension.
TEST_F(ShapeInferenceTest,DotGeneral)1559 TEST_F(ShapeInferenceTest, DotGeneral) {
1560   Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3});
1561   Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14});
1562   Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14});
1563 
1564   DotDimensionNumbers dot_dnums;
1565   dot_dnums.add_lhs_contracting_dimensions(3);
1566   dot_dnums.add_lhs_batch_dimensions(0);
1567   dot_dnums.add_lhs_batch_dimensions(1);
1568 
1569   dot_dnums.add_rhs_contracting_dimensions(2);
1570   dot_dnums.add_rhs_batch_dimensions(0);
1571   dot_dnums.add_rhs_batch_dimensions(1);
1572 
1573   auto inferred_status_match =
1574       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1575                                       /*preferred_element_type=*/absl::nullopt);
1576   ASSERT_IS_OK(inferred_status_match.status());
1577   ASSERT_TRUE(
1578       ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape))
1579       << "inferred: "
1580       << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
1581       << " expected: " << ShapeUtil::HumanString(output_shape);
1582 }
1583 
1584 // BatchMatMul with two contracting dimensions fails.
TEST_F(ShapeInferenceTest,DotWithTwoContractingDimsFails)1585 TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) {
1586   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
1587   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
1588 
1589   DotDimensionNumbers dot_dnums;
1590   dot_dnums.add_lhs_contracting_dimensions(2);
1591   dot_dnums.add_lhs_contracting_dimensions(3);
1592   dot_dnums.add_lhs_batch_dimensions(0);
1593 
1594   dot_dnums.add_rhs_contracting_dimensions(1);
1595   dot_dnums.add_rhs_batch_dimensions(0);
1596 
1597   auto inferred_status =
1598       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1599                                       /*preferred_element_type=*/absl::nullopt);
1600   ASSERT_FALSE(inferred_status.ok());
1601   ASSERT_THAT(inferred_status.status().error_message(),
1602               HasSubstr("Must specify the same number of contracting "
1603                         "dimensions for lhs and rhs."));
1604 }
1605 
TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
  // Two contracting dimensions on each side (lhs {2,3} with rhs {1,2}) plus
  // one batch dimension: [2,11,3,2] . [2,3,2,14] -> [2,11,14].
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 2, 14});
  Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_lhs_contracting_dimensions(3);
  dot_dnums.add_lhs_batch_dimensions(0);

  dot_dnums.add_rhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(2);
  dot_dnums.add_rhs_batch_dimensions(0);

  auto inferred_status =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  // Must be a fatal assertion: ValueOrDie() below would abort the test binary
  // if inference had failed.
  ASSERT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape));
}
1626 
TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) {
  // The size value is s32[1] rather than a scalar, which must be rejected.
  const Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape non_scalar_size = ShapeUtil::MakeShape(S32, {1});
  auto result = ShapeInference::InferSetDimensionSizeShape(
      operand, non_scalar_size, /*dimension=*/0);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("value has to be S32 scalar"));
}
1637 
TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) {
  // The size value is a scalar but of type U32 instead of S32.
  const Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape unsigned_size = ShapeUtil::MakeShape(U32, {});
  auto result = ShapeInference::InferSetDimensionSizeShape(
      operand, unsigned_size, /*dimension=*/0);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("value has to be S32 scalar"));
}
1648 
1649 // BatchMatMul with different batch dimension sizes fails.
TEST_F(ShapeInferenceTest,DotWithMismatchedBatchDimSizesFails)1650 TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
1651   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1652   Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14});
1653 
1654   DotDimensionNumbers dot_dnums;
1655   dot_dnums.add_lhs_contracting_dimensions(2);
1656   dot_dnums.add_lhs_batch_dimensions(0);
1657 
1658   dot_dnums.add_rhs_contracting_dimensions(1);
1659   dot_dnums.add_rhs_batch_dimensions(0);
1660 
1661   auto inferred_status =
1662       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1663                                       /*preferred_element_type=*/absl::nullopt);
1664   ASSERT_FALSE(inferred_status.ok());
1665   ASSERT_THAT(inferred_status.status().error_message(),
1666               HasSubstr("Batch dimension sizes must match"));
1667 }
1668 
1669 // BatchMatMul with different batch dimension numbers passes
TEST_F(ShapeInferenceTest,DotWithMismatchedBatchDimNumbersPasses)1670 TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersPasses) {
1671   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1672   Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14});
1673 
1674   DotDimensionNumbers dot_dnums;
1675   dot_dnums.add_lhs_contracting_dimensions(2);
1676   dot_dnums.add_lhs_batch_dimensions(0);
1677 
1678   dot_dnums.add_rhs_contracting_dimensions(0);
1679   dot_dnums.add_rhs_batch_dimensions(1);
1680 
1681   auto inferred_status =
1682       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1683                                       /*preferred_element_type=*/absl::nullopt);
1684   ASSERT_TRUE(inferred_status.ok());
1685   ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
1686                                ShapeUtil::MakeShape(F32, {2, 11, 14})));
1687 }
1688 
1689 // BatchMatMul with out-of-range dimension numbers fails.
TEST_F(ShapeInferenceTest,DotWithContractingDimNumberOutOfRange)1690 TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) {
1691   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1692   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
1693 
1694   DotDimensionNumbers dot_dnums;
1695   dot_dnums.add_lhs_contracting_dimensions(3);
1696   dot_dnums.add_lhs_batch_dimensions(0);
1697 
1698   dot_dnums.add_rhs_contracting_dimensions(0);
1699   dot_dnums.add_rhs_batch_dimensions(1);
1700 
1701   auto inferred_status =
1702       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1703                                       /*preferred_element_type=*/absl::nullopt);
1704   ASSERT_FALSE(inferred_status.ok());
1705   ASSERT_THAT(inferred_status.status().error_message(),
1706               HasSubstr("A dimension number is out of range"));
1707 }
1708 
1709 // BatchMatMul with non-unique dimension numbers fails.
TEST_F(ShapeInferenceTest,DotWithContractingNonUniqueDimNumber)1710 TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) {
1711   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1712   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
1713 
1714   DotDimensionNumbers dot_dnums;
1715   dot_dnums.add_lhs_contracting_dimensions(0);
1716   dot_dnums.add_lhs_batch_dimensions(0);
1717 
1718   dot_dnums.add_rhs_contracting_dimensions(0);
1719   dot_dnums.add_rhs_batch_dimensions(1);
1720 
1721   auto inferred_status =
1722       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1723                                       /*preferred_element_type=*/absl::nullopt);
1724   ASSERT_FALSE(inferred_status.ok());
1725   ASSERT_THAT(inferred_status.status().error_message(),
1726               HasSubstr("A dimension number is not unique"));
1727 }
1728 
TEST_F(ShapeInferenceTest, DotWithIntegralPreferredElementType) {
  // S8 x S16 with preferred element type S32 yields an S32 result.
  const Shape lhs = ShapeUtil::MakeShape(S8, {32, 32});
  const Shape rhs = ShapeUtil::MakeShape(S16, {32, 32});
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape result_shape,
      ShapeInference::InferDotOpShape(lhs, rhs, dnums,
                                      /*preferred_element_type=*/S32));
  EXPECT_TRUE(
      ShapeUtil::Equal(result_shape, ShapeUtil::MakeShape(S32, {32, 32})));
}
1741 
TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeSameAsInferredType) {
  // BF16 x F32 with preferred element type F32 yields an F32 result.
  const Shape lhs = ShapeUtil::MakeShape(BF16, {32, 32});
  const Shape rhs = ShapeUtil::MakeShape(F32, {32, 32});
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape result_shape,
      ShapeInference::InferDotOpShape(lhs, rhs, dnums,
                                      /*preferred_element_type=*/F32));
  EXPECT_TRUE(
      ShapeUtil::Equal(result_shape, ShapeUtil::MakeShape(F32, {32, 32})));
}
1754 
TEST_F(ShapeInferenceTest, FloatingPointDotWithNarrowerPreferredElementType) {
  // For floating-point operands (BF16 x F32) a narrower preferred type (BF16)
  // is accepted and becomes the result element type.
  const Shape lhs = ShapeUtil::MakeShape(BF16, {32, 32});
  const Shape rhs = ShapeUtil::MakeShape(F32, {32, 32});
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape result_shape,
      ShapeInference::InferDotOpShape(lhs, rhs, dnums,
                                      /*preferred_element_type=*/BF16));
  EXPECT_TRUE(
      ShapeUtil::Equal(result_shape, ShapeUtil::MakeShape(BF16, {32, 32})));
}
1767 
TEST_F(ShapeInferenceTest, FloatingPointDotWithInvalidPreferredElementType) {
  // Floating-point operands (BF16) with an integral preferred type (S32) are
  // rejected.
  const Shape lhs = ShapeUtil::MakeShape(BF16, {32, 32});
  const Shape rhs = ShapeUtil::MakeShape(BF16, {32, 32});
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  const auto status =
      ShapeInference::InferDotOpShape(lhs, rhs, dnums,
                                      /*preferred_element_type=*/S32)
          .status();
  ASSERT_FALSE(status.ok());
  ASSERT_THAT(status.error_message(),
              HasSubstr("must both be integral or both be floating point"));
}
1781 
TEST_F(ShapeInferenceTest, IntegralDotWithFloatingPointPreferredElementType) {
  // Integral operands (S8 x S16) with a floating-point preferred type (F32)
  // are rejected.
  const Shape lhs = ShapeUtil::MakeShape(S8, {32, 32});
  const Shape rhs = ShapeUtil::MakeShape(S16, {32, 32});
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  const auto status =
      ShapeInference::InferDotOpShape(lhs, rhs, dnums,
                                      /*preferred_element_type=*/F32)
          .status();
  ASSERT_FALSE(status.ok());
  ASSERT_THAT(status.error_message(),
              HasSubstr("must both be integral or both be floating point"));
}
1795 
TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeWithDifferentSignedness) {
  // Signed operands (S8 x S16) with an unsigned preferred type (U32) are
  // rejected.
  const Shape lhs = ShapeUtil::MakeShape(S8, {32, 32});
  const Shape rhs = ShapeUtil::MakeShape(S16, {32, 32});
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  const auto status =
      ShapeInference::InferDotOpShape(lhs, rhs, dnums,
                                      /*preferred_element_type=*/U32)
          .status();
  ASSERT_FALSE(status.ok());
  ASSERT_THAT(status.error_message(),
              HasSubstr("must have the same signedness as the original type"));
}
1809 
TEST_F(ShapeInferenceTest, DotWithNarrowerPreferredElementType) {
  // For integral operands (S8 x S16) a preferred type of S8 is rejected as
  // narrower than the original type.
  const Shape lhs = ShapeUtil::MakeShape(S8, {32, 32});
  const Shape rhs = ShapeUtil::MakeShape(S16, {32, 32});
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  const auto status =
      ShapeInference::InferDotOpShape(lhs, rhs, dnums,
                                      /*preferred_element_type=*/S8)
          .status();
  ASSERT_FALSE(status.ok());
  ASSERT_THAT(status.error_message(),
              HasSubstr("must not be narrower than the original type"));
}
1823 
TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
  // Adds a vector to a 16x8 matrix, mapping the vector onto each matrix
  // dimension in turn; only the dimension whose size equals the vector length
  // is accepted.
  const Shape matrix = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape len8 = ShapeUtil::MakeShape(F32, {8});
  const Shape len16 = ShapeUtil::MakeShape(F32, {16});

  // 8-vector along dimension 1 (size 8): ok.
  auto len8_on_dim1 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, matrix, len8, {1});
  ASSERT_IS_OK(len8_on_dim1.status());
  ASSERT_TRUE(ShapeUtil::Equal(len8_on_dim1.ValueOrDie(), matrix));

  // 8-vector along dimension 0 (size 16): error.
  auto len8_on_dim0 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, matrix, len8, {0});
  ASSERT_FALSE(len8_on_dim0.ok());

  // 16-vector along dimension 0 (size 16): ok.
  auto len16_on_dim0 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, matrix, len16, {0});
  ASSERT_IS_OK(len16_on_dim0.status());
  ASSERT_TRUE(ShapeUtil::Equal(len16_on_dim0.ValueOrDie(), matrix));

  // 16-vector along dimension 1 (size 8): error.
  auto len16_on_dim1 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, matrix, len16, {1});
  ASSERT_FALSE(len16_on_dim1.ok());
}
1849 
TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) {
  // Adds a matrix to a 16x8x4 cube for each pair of cube dimensions the matrix
  // can be mapped onto; every case is expected to yield the cube's shape.
  const Shape cube = ShapeUtil::MakeShape(F32, {16, 8, 4});
  const Shape mat_8x4 = ShapeUtil::MakeShape(F32, {8, 4});
  const Shape mat_16x4 = ShapeUtil::MakeShape(F32, {16, 4});
  const Shape mat_16x8 = ShapeUtil::MakeShape(F32, {16, 8});

  // [8,4] onto cube dimensions {1, 2}.
  auto on_dims_1_2 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube,
                                                        mat_8x4, {1, 2});
  ASSERT_IS_OK(on_dims_1_2.status());
  ASSERT_TRUE(ShapeUtil::Equal(on_dims_1_2.ValueOrDie(), cube));

  // [16,4] onto cube dimensions {0, 2}.
  auto on_dims_0_2 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube,
                                                        mat_16x4, {0, 2});
  ASSERT_IS_OK(on_dims_0_2.status());
  ASSERT_TRUE(ShapeUtil::Equal(on_dims_0_2.ValueOrDie(), cube));

  // [16,8] onto cube dimensions {0, 1}.
  auto on_dims_0_1 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube,
                                                        mat_16x8, {0, 1});
  ASSERT_IS_OK(on_dims_0_1.status());
  ASSERT_TRUE(ShapeUtil::Equal(on_dims_0_1.ValueOrDie(), cube));
}
1872 
TEST_F(ShapeInferenceTest,BinOpBroadcastBadDimension)1873 TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
1874   // Test various errors with the broadcast argument.
1875   const Shape tensor = ShapeUtil::MakeShape(F32, {16, 8, 4});
1876   const Shape tensor8_8_8 = ShapeUtil::MakeShape(F32, {8, 8, 8});
1877   const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
1878   const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4});
1879   const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8});
1880 
1881   // "magical" broadcast rejected
1882   auto inferred_status_error1 =
1883       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {});
1884   ASSERT_FALSE(inferred_status_error1.ok());
1885   ASSERT_THAT(inferred_status_error1.status().error_message(),
1886               HasSubstr("Automatic"));
1887 
1888   // broadcast_dimension out of bounds for tensor's rank
1889   auto inferred_status_error2 =
1890       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3});
1891   ASSERT_FALSE(inferred_status_error2.ok());
1892   ASSERT_THAT(inferred_status_error2.status().error_message(),
1893               ContainsRegex("Broadcast dimension number .* too large"));
1894 
1895   // broadcast_dimension doesn't match corresponding dimension
1896   auto inferred_status_error3 =
1897       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0});
1898   ASSERT_FALSE(inferred_status_error3.ok());
1899   ASSERT_THAT(inferred_status_error3.status().error_message(),
1900               HasSubstr("Broadcast dimension 0 mismatch"));
1901 
1902   // broadcast_dimensions list too long
1903   auto inferred_status_error4 = ShapeInference::InferBinaryOpShape(
1904       HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2});
1905   ASSERT_FALSE(inferred_status_error4.ok());
1906   ASSERT_THAT(inferred_status_error4.status().error_message(),
1907               HasSubstr("broadcast_dimensions has to match"));
1908 
1909   // there's a dimension above the rank of the tensor
1910   auto inferred_status_error5 = ShapeInference::InferBinaryOpShape(
1911       HloOpcode::kAdd, tensor, matrix8_4, {3, 0});
1912   ASSERT_FALSE(inferred_status_error5.ok());
1913   ASSERT_THAT(inferred_status_error5.status().error_message(),
1914               ContainsRegex("dimension number .* too large"));
1915 
1916   // broadcasting dimensions don't match in this order
1917   auto inferred_status_error6 = ShapeInference::InferBinaryOpShape(
1918       HloOpcode::kAdd, tensor, matrix8_4, {2, 1});
1919   ASSERT_FALSE(inferred_status_error6.ok());
1920   ASSERT_THAT(inferred_status_error6.status().error_message(),
1921               HasSubstr("dimension 0 mismatch"));
1922 
1923   // The following two tests make sure that broadcasting dimensions are listed
1924   // in a proper (strictly increasing) order, even if the lower-rank array
1925   // matches the higher-rank array in many different ways.
1926   auto inferred_status_error7 = ShapeInference::InferBinaryOpShape(
1927       HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0});
1928   ASSERT_FALSE(inferred_status_error7.ok());
1929   ASSERT_THAT(inferred_status_error7.status().error_message(),
1930               HasSubstr("dimensions order is wrong"));
1931 
1932   auto inferred_status_error8 = ShapeInference::InferBinaryOpShape(
1933       HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0});
1934   ASSERT_FALSE(inferred_status_error8.ok());
1935   ASSERT_THAT(inferred_status_error8.status().error_message(),
1936               HasSubstr("dimensions order is wrong"));
1937 }
1938 
1939 // Tests for the while instruction with proper shapes.
TEST_F(ShapeInferenceTest,WhileWithCorrectShapes)1940 TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) {
1941   Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
1942   ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
1943   ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);
1944   auto inferred_status =
1945       ShapeInference::InferWhileShape(cond, body, result_shape);
1946   ASSERT_IS_OK(inferred_status.status());
1947   Shape inferred = inferred_status.ValueOrDie();
1948   ASSERT_TRUE(ShapeUtil::Equal(result_shape, inferred));
1949 }
1950 
1951 // Tests for the while instruction with wrong shapes.
TEST_F(ShapeInferenceTest,WhileWithBadShapes)1952 TEST_F(ShapeInferenceTest, WhileWithBadShapes) {
1953   Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
1954   ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
1955   ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);
1956 
1957   auto bad_shape_1 = ShapeUtil::MakeProgramShape({s32_, result_shape}, pred_);
1958   auto inferred_status_error1 =
1959       ShapeInference::InferWhileShape(bad_shape_1, body, result_shape);
1960   ASSERT_FALSE(inferred_status_error1.ok());
1961   ASSERT_THAT(inferred_status_error1.status().error_message(),
1962               HasSubstr("Condition must take 1 arguments"));
1963 
1964   auto bad_shape_2 =
1965       ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape);
1966   auto inferred_status_error2 =
1967       ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape);
1968   ASSERT_FALSE(inferred_status_error2.ok());
1969   ASSERT_THAT(inferred_status_error2.status().error_message(),
1970               HasSubstr("Body must take 1 arguments"));
1971 
1972   auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_);
1973   auto inferred_status_error3 =
1974       ShapeInference::InferWhileShape(bad_shape_3, body, result_shape);
1975   ASSERT_FALSE(inferred_status_error3.ok());
1976   ASSERT_THAT(inferred_status_error3.status().error_message(),
1977               HasSubstr("Condition must return a boolean"));
1978 
1979   auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_);
1980   auto inferred_status_error4 =
1981       ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape);
1982   ASSERT_FALSE(inferred_status_error4.ok());
1983   ASSERT_THAT(inferred_status_error4.status().error_message(),
1984               HasSubstr("parameter of condition and body"));
1985 }
1986 
1987 // Tests for the concatenate instruction with dynamic shapes.
TEST_F(ShapeInferenceTest,ConcatenateWithDynamicShapes)1988 TEST_F(ShapeInferenceTest, ConcatenateWithDynamicShapes) {
1989   auto dynamic_shape_1 =
1990       ShapeUtil::MakeShape(F32, {32, 160, 10}, {true, false, false});
1991   auto dynamic_shape_2 =
1992       ShapeUtil::MakeShape(F32, {32, 160, 10}, {false, true, false});
1993   auto inferred_status = ShapeInference::InferConcatOpShape(
1994       {&dynamic_shape_1, &dynamic_shape_2}, /*dimension=*/0);
1995   ASSERT_IS_OK(inferred_status.status());
1996   Shape inferred = inferred_status.ValueOrDie();
1997   ASSERT_TRUE(ShapeUtil::Equal(
1998       ShapeUtil::MakeShape(F32, {64, 160, 10}, {true, true, false}), inferred));
1999 }
2000 
2001 // Tests for the concatenate instruction with proper shapes.
TEST_F(ShapeInferenceTest,ConcatenateWithCorrectShapes)2002 TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) {
2003   auto inferred_status_1 = ShapeInference::InferConcatOpShape(
2004       {&vector_32_, &vector_64_}, /*dimension=*/0);
2005   ASSERT_IS_OK(inferred_status_1.status());
2006   Shape inferred_1 = inferred_status_1.ValueOrDie();
2007   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), inferred_1));
2008 
2009   auto inferred_status_2 = ShapeInference::InferConcatOpShape(
2010       {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0);
2011   ASSERT_IS_OK(inferred_status_2.status());
2012   Shape inferred_2 = inferred_status_2.ValueOrDie();
2013   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), inferred_2));
2014 
2015   auto inferred_status_3 = ShapeInference::InferConcatOpShape(
2016       {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1);
2017   ASSERT_IS_OK(inferred_status_3.status());
2018   Shape inferred_3 = inferred_status_3.ValueOrDie();
2019   ASSERT_TRUE(
2020       ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}), inferred_3));
2021 }
2022 
2023 // Tests for the concatenate instruction with wrong shapes.
TEST_F(ShapeInferenceTest,ConcatenateWithBadShapes)2024 TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
2025   auto inferred_status_error1 =
2026       ShapeInference::InferConcatOpShape({}, /*dimension=*/0);
2027   ASSERT_FALSE(inferred_status_error1.ok());
2028   ASSERT_THAT(inferred_status_error1.status().error_message(),
2029               HasSubstr("Concatenate expects at least one argument"));
2030 
2031   auto inferred_status_error2 =
2032       ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1);
2033   ASSERT_FALSE(inferred_status_error2.ok());
2034   ASSERT_THAT(inferred_status_error2.status().error_message(),
2035               HasSubstr("dimension out of bounds: -1"));
2036 
2037   auto inferred_status_error3 =
2038       ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1);
2039   ASSERT_FALSE(inferred_status_error3.ok());
2040   ASSERT_THAT(inferred_status_error3.status().error_message(),
2041               HasSubstr("dimension out of bounds: 1"));
2042 
2043   Shape tuple = ShapeUtil::MakeTupleShape({vector_32_});
2044   auto inferred_status_error4 = ShapeInference::InferConcatOpShape(
2045       {&vector_32_, &tuple}, /*dimension=*/0);
2046   ASSERT_FALSE(inferred_status_error4.ok());
2047   ASSERT_THAT(
2048       inferred_status_error4.status().error_message(),
2049       HasSubstr("Expected array argument for operand of concatenation"));
2050 
2051   const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
2052   auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
2053       {&vector_32_, &vector_s32}, /*dimension=*/0);
2054   ASSERT_FALSE(inferred_status_error5.ok());
2055   ASSERT_THAT(inferred_status_error5.status().error_message(),
2056               HasSubstr("concatenate arrays with different element types"));
2057 
2058   auto inferred_status_error6 = ShapeInference::InferConcatOpShape(
2059       {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0);
2060   ASSERT_FALSE(inferred_status_error6.ok());
2061   ASSERT_THAT(inferred_status_error6.status().error_message(),
2062               HasSubstr("concatenate arrays that differ in "
2063                         "dimensions other than the one being "
2064                         "concatenated"));
2065 }
2066 
// Checks the padded shape of a {10, 25} operand, then makes dimension 1 go
// negative via negative edge padding and expects an error.
TEST_F(ShapeInferenceTest, Pad) {
  const Shape operand = ShapeUtil::MakeShape(F32, {10, 25});
  const Shape pad_value = ShapeUtil::MakeShape(F32, {});

  PaddingConfig config;
  // Dimension 0: {low: 0, high: 2, interior: 3}.
  auto* dim0 = config.add_dimensions();
  dim0->set_edge_padding_low(0);
  dim0->set_edge_padding_high(2);
  dim0->set_interior_padding(3);
  // Dimension 1: {low: 1, high: 5, interior: 0}.
  auto* dim1 = config.add_dimensions();
  dim1->set_edge_padding_low(1);
  dim1->set_edge_padding_high(5);
  dim1->set_interior_padding(0);

  // Expected sizes: 10 + 0 + 2 + 9 * 3 = 39 and 25 + 1 + 5 = 31.
  const StatusOr<Shape> padded =
      ShapeInference::InferPadShape(operand, pad_value, config);
  ASSERT_IS_OK(padded.status());
  const Shape padded_shape = padded.ValueOrDie();
  ASSERT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), padded_shape));

  // 25 - 20 - 10 is negative, so dimension 1 must be rejected.
  dim1->set_edge_padding_low(-20);
  dim1->set_edge_padding_high(-10);
  const StatusOr<Shape> negative_size =
      ShapeInference::InferPadShape(operand, pad_value, config);
  ASSERT_FALSE(negative_size.ok());
  ASSERT_THAT(negative_size.status().error_message(),
              HasSubstr("negative size for dimension 1"));
}
2097 
// Reversing both dimensions of a {10, 25} array leaves its shape unchanged.
TEST_F(ShapeInferenceTest, Reverse) {
  const Shape operand = ShapeUtil::MakeShape(F32, {10, 25});

  const StatusOr<Shape> reversed =
      ShapeInference::InferReverseShape(operand, {0, 1});
  ASSERT_IS_OK(reversed.status());

  const Shape reversed_shape = reversed.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(operand, reversed_shape));
}
2106 
// Reverse must reject out-of-range dimensions, duplicated dimensions and
// tuple-shaped operands.
TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
  const Shape operand = ShapeUtil::MakeShape(F32, {10, 25});

  // Dimension 2 does not exist on a rank-2 operand.
  const StatusOr<Shape> too_large =
      ShapeInference::InferReverseShape(operand, {0, 2});
  ASSERT_FALSE(too_large.ok());
  ASSERT_THAT(too_large.status().error_message(), HasSubstr("out-of-bounds"));

  // Negative dimension.
  const StatusOr<Shape> negative =
      ShapeInference::InferReverseShape(operand, {0, -1});
  ASSERT_FALSE(negative.ok());
  ASSERT_THAT(negative.status().error_message(), HasSubstr("out-of-bounds"));

  // Dimension 0 listed twice.
  const StatusOr<Shape> repeated =
      ShapeInference::InferReverseShape(operand, {0, 0});
  ASSERT_FALSE(repeated.ok());
  ASSERT_THAT(repeated.status().error_message(), HasSubstr("duplicated"));

  // Tuple operand.
  const Shape tuple_operand = ShapeUtil::MakeTupleShape({operand, operand});
  const StatusOr<Shape> on_tuple =
      ShapeInference::InferReverseShape(tuple_operand, {0});
  ASSERT_FALSE(on_tuple.ok());
  ASSERT_THAT(on_tuple.status().error_message(),
              HasSubstr("Expected array argument"));
}
2135 
// Checks call shape inference: the result is the callee's result shape when
// the argument list matches the callee's parameters in arity and shape;
// otherwise an error is reported.
TEST_F(ShapeInferenceTest, Call) {
  // Nullary callee.
  auto inferred_status0 =
      ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({}, f32_));
  // Fatal assertion: ValueOrDie() below must not be reached on a non-OK
  // result, which would bring down the test binary rather than fail this test.
  ASSERT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));

  // Five parameters of different shapes, all matched by the arguments.
  auto inferred_status1 = ShapeInference::InferCallShape(
      {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_},
      ShapeUtil::MakeProgramShape(
          {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_));
  ASSERT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(s32matrix_64_64_, inferred_status1.ValueOrDie()));

  // Too few arguments.
  auto inferred_status_error0 = ShapeInference::InferCallShape(
      {}, ShapeUtil::MakeProgramShape({f32_}, f32_));
  EXPECT_FALSE(inferred_status_error0.ok());
  EXPECT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("arity must match"));

  // Too many arguments.
  auto inferred_status_error1 = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({}, f32_));
  EXPECT_FALSE(inferred_status_error1.ok());
  EXPECT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("arity must match"));

  // Right arity, wrong argument type (F32 passed for an S32 parameter).
  auto inferred_status_error2 = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_));
  EXPECT_FALSE(inferred_status_error2.ok());
  EXPECT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("parameter must match argument"));
}
2168 
// Transposing {2, 3, 4, 5} with permutation {1, 2, 3, 0} must yield a shape
// compatible with {3, 4, 5, 2}.
TEST_F(ShapeInferenceTest, Transpose) {
  Shape a_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
  auto inferred_shape_and_status =
      ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0});
  // Fatal assertion: ValueOrDie() below must not be reached on a non-OK
  // result, which would bring down the test binary rather than fail this test.
  ASSERT_IS_OK(inferred_shape_and_status);
  Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
  EXPECT_TRUE(ShapeUtil::Compatible(inferred_shape,
                                    ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
}
2178 
// Transposing a rank-1 array with the identity permutation {0} must yield a
// shape compatible with the input.
TEST_F(ShapeInferenceTest, Rank1Transpose) {
  Shape a_shape = ShapeUtil::MakeShape(F32, {5});
  auto inferred_shape_and_status =
      ShapeInference::InferTransposeShape(a_shape, {0});
  // Fatal assertion: ValueOrDie() below must not be reached on a non-OK
  // result, which would bring down the test binary rather than fail this test.
  ASSERT_IS_OK(inferred_shape_and_status);
  Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
  EXPECT_TRUE(
      ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {5})));
}
2188 
TEST_F(ShapeInferenceTest,ConditionalPred)2189 TEST_F(ShapeInferenceTest, ConditionalPred) {
2190   auto inferred_status0 = ShapeInference::InferConditionalShape(
2191       pred_,
2192       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2193        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
2194       {vector_32_, vector_64_});
2195   EXPECT_IS_OK(inferred_status0.status());
2196   EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));
2197 
2198   auto inferred_status1 = ShapeInference::InferConditionalShape(
2199       pred_,
2200       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
2201        ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)},
2202       {matrix_32_48_, vector_32_});
2203   EXPECT_IS_OK(inferred_status1.status());
2204   EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));
2205 
2206   auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
2207   auto inferred_status2 = ShapeInference::InferConditionalShape(
2208       pred_,
2209       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
2210        ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
2211       {matrix_32_48_, tuple_f32_v32});
2212   EXPECT_IS_OK(inferred_status2.status());
2213   EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));
2214 
2215   auto inferred_status_error0 = ShapeInference::InferConditionalShape(
2216       f32_,
2217       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2218        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
2219       {vector_32_, vector_64_});
2220   EXPECT_FALSE(inferred_status_error0.ok());
2221   EXPECT_THAT(inferred_status_error0.status().error_message(),
2222               HasSubstr("must be bool or int32"));
2223 
2224   auto inferred_status_error1 = ShapeInference::InferConditionalShape(
2225       pred_,
2226       {ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
2227        ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
2228       {ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_});
2229   EXPECT_FALSE(inferred_status_error1.ok());
2230   EXPECT_THAT(inferred_status_error1.status().error_message(),
2231               HasSubstr("branch computation 0 must take 1 argument"));
2232 
2233   auto inferred_status_error2 = ShapeInference::InferConditionalShape(
2234       pred_,
2235       {ShapeUtil::MakeProgramShape({vector_64_}, f32_),
2236        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
2237       {vector_32_, vector_64_});
2238   EXPECT_FALSE(inferred_status_error2.ok());
2239   EXPECT_THAT(inferred_status_error2.status().error_message(),
2240               HasSubstr("branch operand 0 must match the shape of the only "
2241                         "parameter of branch computation 0"));
2242 
2243   auto inferred_status_error3 = ShapeInference::InferConditionalShape(
2244       pred_,
2245       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
2246        ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)},
2247       {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})});
2248   EXPECT_FALSE(inferred_status_error3.ok());
2249   EXPECT_THAT(inferred_status_error3.status().error_message(),
2250               HasSubstr("branch computation 1 must take 1 argument"));
2251 
2252   auto inferred_status_error4 = ShapeInference::InferConditionalShape(
2253       pred_,
2254       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2255        ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
2256       {vector_32_, vector_64_});
2257   EXPECT_FALSE(inferred_status_error4.ok());
2258   EXPECT_THAT(inferred_status_error4.status().error_message(),
2259               HasSubstr("branch operand 1 must match the shape of the only "
2260                         "parameter of branch computation 1"));
2261 
2262   auto inferred_status_error5 = ShapeInference::InferConditionalShape(
2263       pred_,
2264       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2265        ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
2266       {vector_32_, vector_64_});
2267   EXPECT_FALSE(inferred_status_error5.ok());
2268   EXPECT_THAT(inferred_status_error5.status().error_message(),
2269               HasSubstr("the result of branch 0 computation and branch 1 "
2270                         "computation must have the same shape"));
2271 }
2272 
TEST_F(ShapeInferenceTest,ConditionalIndexed)2273 TEST_F(ShapeInferenceTest, ConditionalIndexed) {
2274   auto r0s32 = ShapeUtil::MakeShape(S32, {});
2275   auto inferred_status0 = ShapeInference::InferConditionalShape(
2276       r0s32,
2277       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2278        ShapeUtil::MakeProgramShape({vector_64_}, f32_),
2279        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
2280       {vector_32_, vector_64_, vector_64_});
2281   EXPECT_IS_OK(inferred_status0.status());
2282   EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));
2283 
2284   auto inferred_status1 = ShapeInference::InferConditionalShape(
2285       r0s32,
2286       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
2287        ShapeUtil::MakeProgramShape({vector_32_}, vector_64_),
2288        ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)},
2289       {matrix_32_48_, vector_32_, matrix_32_48_});
2290   EXPECT_IS_OK(inferred_status1.status());
2291   EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));
2292 
2293   auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
2294   auto inferred_status2 = ShapeInference::InferConditionalShape(
2295       r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
2296       {tuple_f32_v32});
2297   EXPECT_IS_OK(inferred_status2.status());
2298   EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));
2299 
2300   auto inferred_status_error0 = ShapeInference::InferConditionalShape(
2301       pred_,
2302       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2303        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2304        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
2305       {vector_32_, vector_32_, vector_64_});
2306   EXPECT_FALSE(inferred_status_error0.ok());
2307   EXPECT_THAT(inferred_status_error0.status().error_message(),
2308               HasSubstr("2 == branch_computations.size()"));
2309 
2310   auto inferred_status_error1 = ShapeInference::InferConditionalShape(
2311       r0s32,
2312       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
2313        ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
2314        ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
2315       {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}),
2316        matrix_32_48_});
2317   EXPECT_FALSE(inferred_status_error1.ok());
2318   EXPECT_THAT(inferred_status_error1.status().error_message(),
2319               HasSubstr("branch computation 1 must take 1 argument"));
2320 
2321   auto inferred_status_error2 = ShapeInference::InferConditionalShape(
2322       r0s32,
2323       {ShapeUtil::MakeProgramShape({r0s32}, f32_),
2324        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2325        ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
2326       {r0s32, vector_32_, vector_64_});
2327   EXPECT_FALSE(inferred_status_error2.ok());
2328   EXPECT_THAT(inferred_status_error2.status().error_message(),
2329               HasSubstr("branch operand 2 must match the shape of the only "
2330                         "parameter of branch computation 2"));
2331 
2332   auto inferred_status_error3 = ShapeInference::InferConditionalShape(
2333       r0s32,
2334       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2335        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2336        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2337        ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
2338       {vector_32_, vector_32_, vector_32_, vector_64_});
2339   EXPECT_FALSE(inferred_status_error3.ok());
2340   EXPECT_THAT(inferred_status_error3.status().error_message(),
2341               HasSubstr("the result of branch 0 computation and branch 3 "
2342                         "computation must have the same shape"));
2343 
2344   auto inferred_status_error4 =
2345       ShapeInference::InferConditionalShape(r0s32, {}, {});
2346   EXPECT_FALSE(inferred_status_error4.ok());
2347   EXPECT_THAT(inferred_status_error4.status().error_message(),
2348               HasSubstr("!branch_computations.empty()"));
2349 }
2350 
// Slicing [0, 5) out of a 4-element vector must fail, and the error must
// mention both the size constraint and the argument shape.
TEST_F(ShapeInferenceTest, BadSlice) {
  const Shape operand = ShapeUtil::MakeShape(F32, {4});
  const StatusOr<Shape> slice_result =
      ShapeInference::InferSliceShape(operand, {0}, {5}, {1});
  ASSERT_FALSE(slice_result.ok());

  LOG(INFO) << slice_result.status();

  EXPECT_THAT(slice_result.status().error_message(),
              HasSubstr("less than or equal to dimension size"))
      << slice_result.status();
  EXPECT_THAT(slice_result.status().error_message(),
              HasSubstr("argument shape"))
      << slice_result.status();
}
2365 
// Sort with a 4-element key operand and a 5-element value operand must be
// rejected because the dimensions differ.
TEST_F(ShapeInferenceTest, BadSort) {
  const Shape keys = ShapeUtil::MakeShape(F32, {4});
  const Shape values = ShapeUtil::MakeShape(F32, {5});

  const StatusOr<Shape> sort_result =
      ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values});
  EXPECT_FALSE(sort_result.ok());
  EXPECT_THAT(sort_result.status().error_message(),
              HasSubstr("dimensions must match"))
      << sort_result.status();
}
2376 
// Sort with two value operands where only the second one disagrees with the
// keys' dimensions must still be rejected.
TEST_F(ShapeInferenceTest, BadSortValuesMismatch) {
  const Shape keys = ShapeUtil::MakeShape(F32, {4});
  const Shape matching_values = ShapeUtil::MakeShape(F32, {4});
  const Shape mismatched_values = ShapeUtil::MakeShape(F32, {5});

  const StatusOr<Shape> sort_result = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &matching_values, &mismatched_values});
  EXPECT_FALSE(sort_result.ok());
  EXPECT_THAT(sort_result.status().error_message(),
              HasSubstr("dimensions must match"))
      << sort_result.status();
}
2388 
// Sort with F32 keys and two value operands of other element types (S32 and
// U32) must infer a tuple of the three operand shapes.
TEST_F(ShapeInferenceTest, SortManyValues) {
  auto keys = ShapeUtil::MakeShape(F32, {4});
  auto values_s32 = ShapeUtil::MakeShape(S32, {4});
  auto values_u32 = ShapeUtil::MakeShape(U32, {4});
  StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &values_s32, &values_u32});
  // Fatal assertion: ValueOrDie() below must not be reached on a non-OK
  // result, which would bring down the test binary rather than fail this test.
  ASSERT_IS_OK(statusor);
  Shape inferred_shape = statusor.ValueOrDie();
  EXPECT_TRUE(ShapeUtil::Compatible(
      inferred_shape,
      ShapeUtil::MakeTupleShape({keys, values_s32, values_u32})));
}
2401 
2402 class ScatterGatherShapeInferenceTest : public ShapeInferenceTest {
2403  protected:
2404   const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});
2405   const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5});
2406   const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32});
2407   const Shape s64_4d_tensor_10_9_8_7_1_ =
2408       ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1});
2409   const Shape s64_4d_tensor_10_9_8_7_5_ =
2410       ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
2411   const Shape s64_4d_tensor_5_10_9_7_6_ =
2412       ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6});
2413   const Shape s64_4d_tensor_10_9_5_7_6_ =
2414       ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
2415   const Shape f32_5d_tensor_50_49_48_47_46_ =
2416       ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
2417   const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
2418       {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_});
2419   const ProgramShape to_apply_ =
2420       ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
2421 };
2422 
2423 // Shape inference tests for Gather.
2424 
// Gathers slices of size {64, 1} from matrix_64_48_ using 32 S64 indices into
// operand dimension 1; the expected result is {64, 32}.
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) {
  const GatherDimensionNumbers dim_numbers =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/1);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
                                       dim_numbers,
                                       /*slice_sizes=*/{64, 1}));

  const Shape expected = ShapeUtil::MakeShape(F32, {64, 32});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
2439 
// Gathers slices of size {1, 48} from matrix_64_48_ using 32 S64 indices into
// operand dimension 0; the expected result is {32, 48}.
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) {
  const GatherDimensionNumbers dim_numbers =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{1},
          /*collapsed_slice_dims=*/{0},
          /*start_index_map=*/{0},
          /*index_vector_dim=*/1);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
                                       dim_numbers,
                                       /*slice_sizes=*/{1, 48}));

  const Shape expected = ShapeUtil::MakeShape(F32, {32, 48});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
2454 
// Gathers slices of size {1, 48} from matrix_64_48_ with a {10, 9, 8, 7, 1}
// index tensor whose last dimension is the index vector; the expected result
// is {10, 9, 8, 7, 48}.
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) {
  const GatherDimensionNumbers dim_numbers =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4},
          /*collapsed_slice_dims=*/{0},
          /*start_index_map=*/{0},
          /*index_vector_dim=*/4);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(matrix_64_48_,
                                       s64_4d_tensor_10_9_8_7_1_, dim_numbers,
                                       /*slice_sizes=*/{1, 48}));

  const Shape expected = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
2469 
// Gathers {30, 29, 28, 27, 26} slices from the rank-5 F32 operand using a
// {10, 9, 8, 7, 5} index tensor with nothing collapsed; the expected result
// is the four batch dimensions followed by the five slice dimensions.
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
  const GatherDimensionNumbers dim_numbers =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
          dim_numbers,
          /*slice_sizes=*/{30, 29, 28, 27, 26}));

  const Shape expected =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
2486 
// The operand is dynamic in dimension 1 and the slice covers that dimension
// entirely, so the expected {2, 1} result is dynamic in its first dimension.
TEST_F(ScatterGatherShapeInferenceTest, DynamicGatherEntireDimension) {
  const Shape operand =
      ShapeUtil::MakeShape(F32, {3, 2, 1}, {false, true, false});
  const Shape scalar_index = ShapeUtil::MakeShape(S64, {});
  const GatherDimensionNumbers dim_numbers =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0, 1},
          /*collapsed_slice_dims=*/{0},
          /*start_index_map=*/{0},
          /*index_vector_dim=*/0);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(operand, scalar_index, dim_numbers,
                                       /*slice_sizes=*/{1, 2, 1}));

  const Shape expected = ShapeUtil::MakeShape(F32, {2, 1}, {true, false});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
2503 
TEST_F(ScatterGatherShapeInferenceTest, DynamicGatherCollapsedDimension) {
  // The only dynamic operand dimension (0) is the collapsed one, so the
  // expected result has no dynamic dimensions.
  const Shape operand =
      ShapeUtil::MakeShape(F32, {3, 2, 1}, {true, false, false});
  const Shape indices = ShapeUtil::MakeShape(S64, {});
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{0, 1},
      /*collapsed_slice_dims=*/{0},
      /*start_index_map=*/{0},
      /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(operand, indices, dim_numbers,
                                       /*slice_sizes=*/{1, 2, 1}));

  const Shape expected = ShapeUtil::MakeShape(F32, {2, 1}, {false, false});
  EXPECT_TRUE(ShapeUtil::Equal(inferred, expected))
      << ShapeUtil::HumanString(inferred);
}
2520 
TEST_F(ScatterGatherShapeInferenceTest, DynamicIndices) {
  // The indices shape has a dynamic dimension 1; the test expects that flag
  // to be carried over to dimension 1 of the inferred shape.
  const Shape operand = ShapeUtil::MakeShape(F32, {3, 2, 2});
  const Shape indices =
      ShapeUtil::MakeShape(S64, {3, 4, 2}, {false, true, false});
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{2, 3},
      /*collapsed_slice_dims=*/{0},
      /*start_index_map=*/{0, 1},
      /*index_vector_dim=*/2);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(operand, indices, dim_numbers,
                                       /*slice_sizes=*/{1, 2, 2}));

  const Shape expected =
      ShapeUtil::MakeShape(F32, {3, 4, 2, 2}, {false, true, false, false});
  EXPECT_TRUE(ShapeUtil::Equal(inferred, expected))
      << ShapeUtil::HumanString(inferred);
}
2538 
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
  // index_vector_dim points at dimension 2 of the indices rather than the
  // last one; the expected batch dims are {10, 9, 7, 6}.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/2);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
          dim_numbers, /*slice_sizes=*/{30, 29, 28, 27, 26}));

  const Shape expected =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
  EXPECT_TRUE(ShapeUtil::Equal(inferred, expected))
      << ShapeUtil::HumanString(inferred);
}
2556 
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
  // index_vector_dim is the leading dimension (0) of the indices; the
  // expected batch dims are {10, 9, 7, 6}.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
          dim_numbers, /*slice_sizes=*/{30, 29, 28, 27, 26}));

  const Shape expected =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
  EXPECT_TRUE(ShapeUtil::Equal(inferred, expected))
      << ShapeUtil::HumanString(inferred);
}
2574 
TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) {
  // Equivalent to a dynamic slice: every output dimension is an offset dim,
  // so the expected shape is exactly the slice sizes.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{0, 1, 2, 3, 4},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, dim_numbers,
          /*slice_sizes=*/{30, 29, 28, 27, 26}));

  const Shape expected = ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26});
  EXPECT_TRUE(ShapeUtil::Equal(inferred, expected))
      << ShapeUtil::HumanString(inferred);
}
2591 
TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) {
  // Here the gather indices "tensor" is a single scalar S. It selects the
  // slice [S,0,0,0,0]..[S,30,29,28,27], and with operand dimension 0
  // collapsed the result is shaped [30,29,28,27].
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{0, 1, 2, 3},
      /*collapsed_slice_dims=*/{0},
      /*start_index_map=*/{0},
      /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_scalar_, dim_numbers,
          /*slice_sizes=*/{1, 30, 29, 28, 27}));

  const Shape expected = ShapeUtil::MakeShape(F32, {30, 29, 28, 27});
  EXPECT_TRUE(ShapeUtil::Equal(inferred, expected))
      << ShapeUtil::HumanString(inferred);
}
2609 
TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) {
  // Passing tuple_shape_ as the gather operand must be rejected.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{0},
      /*collapsed_slice_dims=*/{1},
      /*start_index_map=*/{1},
      /*index_vector_dim=*/1);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      tuple_shape_, s64_vector_32_, dim_numbers, /*slice_sizes=*/{64, 1});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Expected array argument for input"))
      << inferred.status();
}
2624 
TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
  // Passing tuple_shape_ as the gather indices must be rejected.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{0},
      /*collapsed_slice_dims=*/{1},
      /*start_index_map=*/{1},
      /*index_vector_dim=*/0);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      s64_vector_32_, tuple_shape_, dim_numbers, /*slice_sizes=*/{64, 1});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Expected array argument for gather indices"))
      << inferred.status();
}
2639 
TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
  // vector_32_ is an F32 shape; using it as gather indices must be rejected.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{0},
      /*collapsed_slice_dims=*/{1},
      /*start_index_map=*/{1},
      /*index_vector_dim=*/0);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      s64_vector_32_, vector_32_, dim_numbers, /*slice_sizes=*/{64, 1});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Gather indices parameter must be an integral tensor"))
      << inferred.status();
}
2654 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingWindowIndices) {
  // offset_dims lists 8 ahead of 7, i.e. it is not in ascending order.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 8, 7},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Output window dimensions in gather op must be ascending"))
      << inferred.status();
}
2671 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowIndices) {
  // offset_dims contains 7 twice.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 7},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Output window dimensions in gather op must not repeat"))
      << inferred.status();
}
2688 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
  // offset_dims entries 99, 100 and 101 are far out of range; the error is
  // expected to name position 2, the first offending entry.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 99, 100, 101},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Offset dimension 2 in gather op is out of bounds"))
      << inferred.status();
}
2704 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) {
  // The last offset dim is 9 where the valid variants of this test use 8;
  // the error is expected to name position 4.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 9},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Offset dimension 4 in gather op is out of bounds"))
      << inferred.status();
}
2720 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
  // Five offset dims plus one collapsed dim are requested for five slice
  // sizes, so the counts do not add up.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{4},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("All components of the offset index in a gather op must either "
                "be a offset dimension or explicitly collapsed"))
      << inferred.status();
}
2738 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
  // collapsed_slice_dims contains 19, outside the reported range [0, 5).
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{0, 1, 2, 3, 19},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Invalid collapsed_slice_dims set in gather op; valid "
                        "range is [0, 5), got: 19"))
      << inferred.status();
}
2755 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
  // collapsed_slice_dims contains 3 twice.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{0, 1, 2, 3, 3},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Repeated dimensions not allowed in "
                        "collapsed_slice_dims in gather op"))
      << inferred.status();
}
2772 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
  // start_index_map has only four entries, while the expected error reports
  // a bound of 5 for the index vector dimension of the indices.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Gather op has 4 elements in start_index_map and "
                        "the bound of dimension index_vector_dim=4 of "
                        "start_indices is 5. These two numbers must be equal."))
      << inferred.status();
}
2790 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
  // Entry 4 of start_index_map maps to 7, outside the reported domain
  // [0, 5).
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 7},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7"))
      << inferred.status();
}
2806 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
  // start_index_map contains 3 twice.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 3},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Repeated dimensions are not allowed in start_index_map"))
      << inferred.status();
}
2823 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
  // collapsed_slice_dims is given as {2, 1}, i.e. unsorted.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{2, 1},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{1, 1, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("collapsed_slice_dims in gather op must be sorted"))
      << inferred.status();
}
2839 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowBoundsTooLarge) {
  // The slice size at index 3 is 300, outside the reported range [0, 48).
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7},
      /*collapsed_slice_dims=*/{2},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 1, 300, 26});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Slice size at index 3 in gather op is out of range, "
                        "must be within [0, 48), got 300."))
      << inferred.status();
}
2856 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
  // Only four slice sizes are supplied; the expected error requires one per
  // input dimension.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 26});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Gather op must have one slice size for every input dimension"))
      << inferred.status();
}
2873 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
  // Dimension 1 is collapsed but its slice size is 29, not 1 or 0.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7},
      /*collapsed_slice_dims=*/{1},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 26, 20});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Gather op can only collapse slice dims with bound 1 or 0, "
                "but bound is 29 for index 1 at position 0."))
      << inferred.status();
}
2891 
TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
  // index_vector_dim is 32; the expected error requires it to lie within
  // [0, rank(start_indices) + 1).
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/32);
  StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Gather index leaf dimension must be within [0, "
                        "rank(start_indices) + 1)"))
      << inferred.status();
}
2908 
2909 // Shape inference tests for Scatter.
2910 
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdates) {
  // Updates of shape {64, 32} with dimension 0 as the update window; the
  // inferred scatter shape must equal the operand shape.
  const Shape updates = ShapeUtil::MakeShape(F32, {64, 32});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_, s64_vector_32_, updates,
                                        to_apply_, dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
2924 
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdatesV2) {
  // Updates of shape {32, 48} with dimension 1 as the update window; the
  // inferred scatter shape must equal the operand shape.
  const Shape updates = ShapeUtil::MakeShape(F32, {32, 48});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{1},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_, s64_vector_32_, updates,
                                        to_apply_, dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
2938 
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdates) {
  // The update window (dimension 0, bound 10) is smaller than the operand;
  // inference must still succeed and return the operand shape.
  const Shape updates = ShapeUtil::MakeShape(F32, {10, 32});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_, s64_vector_32_, updates,
                                        to_apply_, dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
2952 
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdatesV2) {
  // The update window (dimension 1, bound 8) is smaller than the operand;
  // inference must still succeed and return the operand shape.
  const Shape updates = ShapeUtil::MakeShape(F32, {32, 8});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{1},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_, s64_vector_32_, updates,
                                        to_apply_, dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
2966 
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) {
  // The update window bound (65) is larger than what the operand allows, so
  // inference must fail.
  const Shape updates = ShapeUtil::MakeShape(F32, {65, 32});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates, to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << inferred.status();
}
2983 
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) {
  // The update window bound (49) is larger than what the operand allows, so
  // inference must fail.
  const Shape updates = ShapeUtil::MakeShape(F32, {32, 49});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{1},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates, to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << inferred.status();
}
3000 
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterWithUpdatesNotMatchingIndices) {
  // The scatter dimension of the updates has bound 31, which does not match
  // the scatter indices, so inference must fail.
  const Shape updates = ShapeUtil::MakeShape(F32, {64, 31});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates, to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << inferred.status();
}
3019 
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterWithUpdatesNotMatchingIndicesV2) {
  // The scatter dimension of the updates has bound 31, which does not match
  // the scatter indices, so inference must fail.
  const Shape updates = ShapeUtil::MakeShape(F32, {31, 48});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{1},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates, to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << inferred.status();
}
3038 
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdates) {
  // Rank-5 updates whose last dimension (bound 48) is the update window; the
  // inferred scatter shape must equal the operand shape.
  const Shape updates = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_,
                                        s64_4d_tensor_10_9_8_7_1_, updates,
                                        to_apply_, dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
3053 
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdatesV2) {
  // Rank-5 updates whose last dimension (bound 64) is the update window and
  // operand dimension 1 is the inserted one; the inferred scatter shape must
  // equal the operand shape.
  const Shape updates = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 64});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_,
                                        s64_4d_tensor_10_9_8_7_1_, updates,
                                        to_apply_, dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
3068 
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdates) {
  // The update window (bound 10) is smaller than the operand; inference must
  // still succeed and return the operand shape.
  const Shape updates = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 10});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_,
                                        s64_4d_tensor_10_9_8_7_1_, updates,
                                        to_apply_, dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
3083 
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) {
  // The update window (bound 12) is smaller than the operand and operand
  // dimension 1 is the inserted one; inference must still succeed and return
  // the operand shape.
  const Shape updates = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 12});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_,
                                        s64_4d_tensor_10_9_8_7_1_, updates,
                                        to_apply_, dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
3098 
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) {
  // The update window bound (65) is larger than what the operand allows, so
  // inference must fail.
  const Shape updates = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 65});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, updates, to_apply_,
      dim_numbers);

  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << inferred.status();
}
3115 
// The leading scatter dimension of the updates has bound 9, which the test
// expects to disagree with the scatter indices (the fixture name suggests a
// bound of 10 there). Inference must fail with the "must be same as"
// diagnostic.
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterNdWithUpdatesNotMatchingIndices) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
      ShapeUtil::MakeShape(F32, {9, 9, 8, 7, 64}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/4));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << statusor.status();
}
3134 
// Scatter with five update window dimensions (updates dims 4..8), no inserted
// window dimensions, and the index vector in dimension 4 of the indices.
// Inference must succeed and the inferred shape must equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfBatchDynamicUpdateSlice) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}),
          to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/4)));
  // On mismatch, print the inferred shape to make the failure actionable.
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(scatter_shape);
}
3150 
// Same dimension numbers as TfBatchDynamicUpdateSlice, but the index vector
// lives in dimension 2 of the scatter indices rather than dimension 4.
// Inference must succeed and the inferred shape must equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDim) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
          ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}),
          to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/2)));

  // On mismatch, print the inferred shape to make the failure actionable.
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(scatter_shape);
}
3167 
// Variant of NonDefaultScatterIndicesLeafDim with the index vector in
// dimension 0 of the scatter indices. Inference must succeed and the inferred
// shape must equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
          ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}),
          to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0)));

  // On mismatch, print the inferred shape to make the failure actionable.
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(scatter_shape);
}
3184 
// Every dimension of the updates is a window dimension (dims 0..4), so the
// updates have no scatter dimensions at all. Inference must succeed and the
// inferred shape must equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, NoUpdateScatterDims) {
  // This is equivalent to a dynamic update slice.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
          ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}), to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{0, 1, 2, 3, 4},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0)));

  // On mismatch, print the inferred shape to make the failure actionable.
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(scatter_shape);
}
3201 
// Scatter driven by scalar indices (s64_scalar_) with operand dimension 0
// inserted. Inference must succeed and the inferred shape must equal the
// operand shape.
TEST_F(ScatterGatherShapeInferenceTest, ScalarScatterIndices) {
  // The scalar indices "tensor" is a scalar S here that's used to update a
  // [30,29,28,27] shaped tensor within the operand at position S.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
          ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{0, 1, 2, 3},
              /*inserted_window_dims=*/{0},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/0)));

  // On mismatch, print the inferred shape to make the failure actionable.
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(scatter_shape);
}
3219 
// Passing tuple_shape_ as the operand (first argument) must be rejected with
// the "Expected array argument for operand" diagnostic.
TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedTensorInput) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      tuple_shape_, s64_vector_32_, s64_vector_32_, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Expected array argument for operand"))
      << statusor.status();
}
3233 
// Passing tuple_shape_ as the scatter indices (second argument) must be
// rejected with the "Expected array argument for scatter indices" diagnostic.
TEST_F(ScatterGatherShapeInferenceTest,
       ScatterWithTupleShapedScatterIndicesInput) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      s64_vector_32_, tuple_shape_, s64_vector_32_, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/0));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Expected array argument for scatter indices"))
      << statusor.status();
}
3248 
// Passing tuple_shape_ as the updates (third argument) must be rejected with
// the "Expected array argument for updates" diagnostic.
TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      s64_vector_32_, s64_vector_32_, tuple_shape_, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/0));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Expected array argument for updates"))
      << statusor.status();
}
3262 
// Passing the F32 vector_32_ shape as the scatter indices must be rejected:
// the diagnostic requires the indices to be an integral tensor.
TEST_F(ScatterGatherShapeInferenceTest, FloatingPointScatterIndicesInput) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      s64_vector_32_, vector_32_, s64_vector_32_, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/0));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Scatter indices parameter must be an integral tensor"))
      << statusor.status();
}
3276 
// index_vector_dim=10 is expected to fall outside the range the diagnostic
// allows, [0, rank(scatter_indices) + 1). Inference must fail with that
// message.
TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/10));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Scatter index leaf dimension must be within [0, "
                        "rank(scatter_indices) + 1)"))
      << statusor.status();
}
3292 
// The updates shape has eight dimensions; the expected diagnostic says a rank
// of 7 is required. Inference must fail with that rank-mismatch message.
TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdates) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 50}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Updates tensor must be of rank 7; got 8."))
      << statusor.status();
}
3307 
// Supplies an update computation whose program shape takes a single f32
// parameter. Inference must fail, reporting that the reduction function must
// take 2 parameters.
TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdateComputation) {
  const ProgramShape invalid_update_computation =
      ShapeUtil::MakeProgramShape({f32_}, f32_);
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}),
      invalid_update_computation,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Reduction function must take 2 parameters, but takes 1"))
      << statusor.status();
}
3326 
// update_window_dims is {4, 5, 6, 8, 7}, i.e. not in ascending order.
// Inference must fail with the "must be sorted" diagnostic.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6, 8, 7},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("update_window_dims in scatter op must be sorted"))
      << statusor.status();
}
3342 
// update_window_dims is {4, 5, 6, 7, 7}, i.e. dimension 7 appears twice.
// Inference must fail with the "must not repeat" diagnostic.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedUpdateWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6, 7, 7},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("update_window_dims in scatter op must not repeat"))
      << statusor.status();
}
3358 
// update_window_dims contains 9, which lies outside the valid range [0, 9)
// named in the expected diagnostic (the updates shape has nine dimensions).
// Inference must fail with that message.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6, 7, 9},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Invalid update_window_dims set in scatter op; valid "
                        "range is [0, 9)"))
      << statusor.status();
}
3375 
// inserted_window_dims is {2, 1}, i.e. not in ascending order. Inference must
// fail with the "must be sorted" diagnostic.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{2, 1},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("inserted_window_dims in scatter op must be sorted"))
      << statusor.status();
}
3391 
// inserted_window_dims is {1, 1}, i.e. dimension 1 appears twice. Inference
// must fail with the "must not repeat" diagnostic.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedInsertedWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 1},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("inserted_window_dims in scatter op must not repeat"))
      << statusor.status();
}
3407 
// inserted_window_dims contains 5, which lies outside the valid range [0, 5)
// named in the expected diagnostic. Inference must fail with that message.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 5},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Invalid inserted_window_dims set in scatter op; valid "
                        "range is [0, 5)"))
      << statusor.status();
}
3424 
// scatter_dims_to_operand_dims has four entries, while the expected
// diagnostic reports a bound of 5 for dimension index_vector_dim=4 of the
// scatter indices. Inference must fail because the two numbers differ.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3},
          /*index_vector_dim=*/4));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and "
                "the bound of dimension index_vector_dim=4 of scatter_indices "
                "is 5. These two numbers must be equal"))
      << statusor.status();
}
3443 
// Entry 4 of scatter_dims_to_operand_dims maps to operand dimension 10, which
// lies outside the domain [0, 5) named in the expected diagnostic. Inference
// must fail with that message.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10},
          /*index_vector_dim=*/4));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain "
                        "is [0, 5), got: 4->10"))
      << statusor.status();
}
3460 
// scatter_dims_to_operand_dims is {0, 1, 2, 2, 3}, i.e. operand dimension 2
// is targeted twice. Inference must fail with the "Repeated dimensions not
// allowed" diagnostic.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3},
          /*index_vector_dim=*/4));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Repeated dimensions not allowed in scatter_dims_to_operand_dims"))
      << statusor.status();
}
3478 
// Four update window dimensions and no inserted window dimensions give a
// window of size 4, which the expected diagnostic says does not match the
// operand of rank 5. Inference must fail with that message.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_InsufficientWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
      ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0, 1, 2, 3},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/0));
  // ASSERT (not EXPECT) so the message check below never runs on an OK status.
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Scatter op has window of size 4; doesn't match operand of rank 5."))
      << statusor.status();
}
3496 
3497 }  // namespace
3498 }  // namespace xla
3499