• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17 
18 #include <string>
19 
20 #include "absl/strings/string_view.h"
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/xla/client/padding.h"
23 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/test.h"
26 #include "tensorflow/compiler/xla/test_helpers.h"
27 #include "tensorflow/compiler/xla/types.h"
28 #include "tensorflow/compiler/xla/xla_data.pb.h"
29 
30 namespace xla {
31 namespace {
32 
33 using ::testing::ContainsRegex;
34 using ::testing::HasSubstr;
35 
36 class ShapeInferenceTest : public ::testing::Test {
37  protected:
38   // Some handy scalar shapes.
39   const Shape s32_ = ShapeUtil::MakeShape(S32, {});
40   const Shape f16_ = ShapeUtil::MakeShape(F16, {});
41   const Shape f32_ = ShapeUtil::MakeShape(F32, {});
42   const Shape f64_ = ShapeUtil::MakeShape(F64, {});
43   const Shape pred_ = ShapeUtil::MakeShape(PRED, {});
44 
45   // Some handy vector and matrix shapes of F32 type.
46   // Suffix: vector_length_, matrix_rows_cols_
47   const Shape vector_32_ = ShapeUtil::MakeShape(F32, {32});
48   const Shape vector_64_ = ShapeUtil::MakeShape(F32, {64});
49   const Shape matrix_32_48_ = ShapeUtil::MakeShape(F32, {32, 48});
50   const Shape matrix_32_64_ = ShapeUtil::MakeShape(F32, {32, 64});
51   const Shape matrix_64_48_ = ShapeUtil::MakeShape(F32, {64, 48});
52 
53   // Some handy S32 arrays.
54   const Shape s32matrix_64_64_ = ShapeUtil::MakeShape(S32, {64, 64});
55 };
56 
57 // Subclass for testing InferReduceShape.
58 class ReduceShapeInferenceTest : public ShapeInferenceTest {
59  protected:
60   // Helper that runs reduce shape inference with the input 'arg' and given
61   // dimensions to reduce, and checks the inferred shape is as expected. The
62   // element type here is hard-coded to F32.
ExpectInferredReduceShape(const Shape & expected_inferred_shape,const Shape & arg,absl::Span<const int64> dimensions_to_reduce)63   void ExpectInferredReduceShape(const Shape& expected_inferred_shape,
64                                  const Shape& arg,
65                                  absl::Span<const int64> dimensions_to_reduce) {
66     ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
67     auto inferred_status = ShapeInference::InferReduceShape(
68         {&arg, &f32_}, dimensions_to_reduce, to_apply);
69     EXPECT_IS_OK(inferred_status.status());
70     EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
71                                  inferred_status.ValueOrDie()));
72   }
73 };
74 
75 // Subclass for testing InferSelectAndScatterShape.
76 class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest {
77  protected:
SelectAndScatterShapeInferenceTest()78   SelectAndScatterShapeInferenceTest() {
79     operand_shape_ = ShapeUtil::MakeShape(F32, {8, 16});
80     source_shape_ = ShapeUtil::MakeShape(F32, {4, 8});
81     WindowDimension dim;
82     dim.set_size(2);
83     dim.set_stride(2);
84     dim.set_padding_low(0);
85     dim.set_padding_high(0);
86     dim.set_window_dilation(1);
87     dim.set_base_dilation(1);
88     *window_.add_dimensions() = dim;
89     *window_.add_dimensions() = dim;
90     init_value_shape_ = ShapeUtil::MakeShape(F32, {});
91     select_program_shape_ = ShapeUtil::MakeProgramShape(
92         {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
93     scatter_program_shape_ = ShapeUtil::MakeProgramShape(
94         {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
95   }
96 
97   Shape operand_shape_;
98   Shape source_shape_;
99   Window window_;
100   Shape init_value_shape_;
101   ProgramShape select_program_shape_;
102   ProgramShape scatter_program_shape_;
103 };
104 
TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
  // Negating a matrix must yield a shape identical to the operand's.
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  const auto negate_result =
      ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, operand);
  ASSERT_IS_OK(negate_result.status());
  ASSERT_TRUE(ShapeUtil::Equal(operand, negate_result.ValueOrDie()));
}
112 
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) {
  // Select requires array operands, so tuple-shaped branches are rejected.
  const Shape branch = ShapeUtil::MakeTupleShape({s32_, f32_});
  const auto select_result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, branch, branch);
  ASSERT_FALSE(select_result.ok());
  const auto& status = select_result.status();
  ASSERT_THAT(status.error_message(),
              HasSubstr("Expected array argument for select"));
}
121 
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
  // A scalar predicate is not accepted for matrix-shaped branches; the
  // predicate has to match the operand shape.
  const auto select_result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(select_result.ok());
  const auto& status = select_result.status();
  ASSERT_THAT(
      status.error_message(),
      HasSubstr("Operands to select and predicate must be the same shape"));
}
130 
TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) {
  // PRED[64,48] selecting between two F32[64,48] arrays yields F32[64,48].
  const Shape pred_matrix = ShapeUtil::MakeShape(PRED, {64, 48});
  const auto select_result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_matrix, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(select_result.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, select_result.ValueOrDie()));
}
138 
TEST_F(ShapeInferenceTest, SelectBadShapes) {
  // Shorthand for inferring the shape of select(pred, on_true, on_false).
  auto infer_select = [](const Shape& pred, const Shape& on_true,
                         const Shape& on_false) {
    return ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, pred,
                                               on_true, on_false);
  };

  // The two branches disagree in shape.
  const auto mismatched_branches =
      infer_select(pred_, matrix_64_48_, matrix_32_64_);
  ASSERT_FALSE(mismatched_branches.ok());
  ASSERT_THAT(mismatched_branches.status().error_message(),
              HasSubstr("Operands to select must be the same shape"));

  // The predicate has a non-PRED element type.
  const auto non_pred_predicate =
      infer_select(s32_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(non_pred_predicate.ok());
  ASSERT_THAT(non_pred_predicate.status().error_message(),
              HasSubstr("pred operand must have PRED"));

  // The predicate's shape differs from the branches' shape.
  const auto mismatched_predicate = infer_select(
      ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(mismatched_predicate.ok());
  ASSERT_THAT(
      mismatched_predicate.status().error_message(),
      HasSubstr("Operands to select and predicate must be the same shape"));

  // Tuples have a TUPLE element type and cannot be the pred of a select.
  const Shape f32_pair = ShapeUtil::MakeTupleShape({f32_, f32_});
  const auto tuple_predicate = infer_select(
      ShapeUtil::MakeTupleShape({pred_, pred_}), f32_pair, f32_pair);
  ASSERT_FALSE(tuple_predicate.ok());
  ASSERT_THAT(tuple_predicate.status().error_message(),
              HasSubstr("Expected array argument for select pred"));
}
169 
TEST_F(ShapeInferenceTest, ClampAllMatrix) {
  // min, operand and max all share one matrix shape, which is the result.
  const auto clamp_result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(clamp_result.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, clamp_result.ValueOrDie()));
}
176 
TEST_F(ShapeInferenceTest, ClampAllScalar) {
  // Clamping scalars produces a scalar of the same element type.
  const auto clamp_result =
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_);
  ASSERT_IS_OK(clamp_result.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, clamp_result.ValueOrDie()));
}
183 
TEST_F(ShapeInferenceTest, ClampMinScalar) {
  // A scalar min combined with matrix operand and max is rejected.
  const auto clamp_result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(clamp_result.ok());
  const auto& status = clamp_result.status();
  ASSERT_THAT(status.error_message(), HasSubstr("Clamp with different shapes"));
}
191 
TEST_F(ShapeInferenceTest, ClampMaxScalar) {
  // A scalar max combined with matrix min and operand is rejected.
  const auto clamp_result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_);
  ASSERT_FALSE(clamp_result.ok());
  const auto& status = clamp_result.status();
  ASSERT_THAT(status.error_message(), HasSubstr("Clamp with different shapes"));
}
199 
TEST_F(ShapeInferenceTest, ClampOperandScalar) {
  // A scalar operand between matrix min and max is rejected.
  const auto clamp_result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_);
  ASSERT_FALSE(clamp_result.ok());
  const auto& status = clamp_result.status();
  ASSERT_THAT(status.error_message(), HasSubstr("Clamp with different shapes"));
}
207 
TEST_F(ShapeInferenceTest, ClampMinMatrix) {
  // A matrix min with scalar operand and max is rejected.
  const auto clamp_result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, f32_);
  ASSERT_FALSE(clamp_result.ok());
  const auto& status = clamp_result.status();
  ASSERT_THAT(status.error_message(), HasSubstr("Clamp with different shapes"));
}
215 
TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
  // A matrix max with scalar min and operand is rejected.
  const auto clamp_result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, f32_, matrix_64_48_);
  ASSERT_FALSE(clamp_result.ok());
  const auto& status = clamp_result.status();
  ASSERT_THAT(status.error_message(), HasSubstr("Clamp with different shapes"));
}
223 
TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
  // A matrix operand between scalar min and max is rejected.
  const auto clamp_result = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, f32_);
  ASSERT_FALSE(clamp_result.ok());
  const auto& status = clamp_result.status();
  ASSERT_THAT(status.error_message(), HasSubstr("Clamp with different shapes"));
}
231 
TEST_F(ShapeInferenceTest, ClampBadShapes) {
  // True iff clamp(lo, operand, hi) shape inference succeeds.
  auto clamp_ok = [](const Shape& lo, const Shape& operand, const Shape& hi) {
    return ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, lo, operand,
                                               hi)
        .ok();
  };

  // Type mismatch
  ASSERT_FALSE(clamp_ok(s32_, f32_, f32_));
  ASSERT_FALSE(clamp_ok(f32_, s32_, f32_));
  ASSERT_FALSE(clamp_ok(f32_, f32_, s32_));

  // Dimension mismatch
  ASSERT_FALSE(clamp_ok(vector_64_, vector_32_, vector_32_));
  ASSERT_FALSE(clamp_ok(vector_32_, vector_64_, vector_32_));
  ASSERT_FALSE(clamp_ok(vector_32_, vector_32_, vector_64_));

  // Dimension mismatch, where one operand is a scalar
  ASSERT_FALSE(clamp_ok(vector_64_, vector_32_, f32_));
  ASSERT_FALSE(clamp_ok(vector_64_, f32_, vector_32_));
  ASSERT_FALSE(clamp_ok(f32_, vector_64_, vector_32_));
}
264 
TEST_F(ShapeInferenceTest, Complex) {
  // Infers the shape of complex(lhs, rhs) with the given broadcast dimensions.
  // Nothing is captured, so the lambda uses an empty capture list.
  auto complex_shape = [](const Shape& lhs, const Shape& rhs,
                          absl::Span<const int64> bcast) {
    return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
                                              bcast);
  };
  // Inputs must be FP.
  ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
  ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
  // Component types must match.
  ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
  // Only F32->C64 and F64->C128 supported.
  ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok());
  // Validate correct uses.
  Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
  TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {})));
  // Vector combined with a scalar, in either operand order.
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  // Two vectors of identical shape.
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, vector_32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));

  Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64});
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(vector_64_, matrix_32_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, vector_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, matrix_32_64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));

  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {})));
}
305 
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
  // kTuple over (s32[], f32[]) produces the tuple shape (s32[], f32[]).
  const Shape expected_tuple = ShapeUtil::MakeTupleShape({s32_, f32_});
  StatusOr<Shape> tuple_result =
      ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_});
  ASSERT_IS_OK(tuple_result.status());
  ASSERT_TRUE(ShapeUtil::Equal(tuple_result.ValueOrDie(), expected_tuple));
}
313 
TEST_F(ShapeInferenceTest, ReduceWindowInHalf) {
  // A 2x2 window with stride 2 and no padding over an 8x8 matrix yields 4x4.
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8});
  Window window;
  WindowDimension dim;
  dim.set_size(2);
  dim.set_stride(2);
  dim.set_padding_low(0);
  dim.set_padding_high(0);
  dim.set_window_dilation(1);
  dim.set_base_dilation(1);
  *window.add_dimensions() = dim;
  *window.add_dimensions() = dim;
  Shape init_value_shape = ShapeUtil::MakeShape(F32, {});
  // Reducer: (F32[], F32[]) -> F32[].
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      matrix_shape, init_value_shape, window, to_apply);

  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), inferred));
}
338 
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) {
  // With the fixture's consistent inputs the result has the operand's shape.
  const auto scatter_result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_IS_OK(scatter_result.status());
  const Shape result_shape = scatter_result.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, result_shape));
}
347 
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) {
  // The fixture's window maps the F32[8,16] operand onto a F32[4,8] source;
  // a F32[4,6] source is inconsistent with it.
  const Shape bad_source = ShapeUtil::MakeShape(F32, {4, 6});
  const auto scatter_result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, bad_source,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(scatter_result.ok());
  const auto& status = scatter_result.status();
  ASSERT_THAT(status.error_message(), HasSubstr("Source shape does not match"));
}
357 
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
  // A select computation with only one parameter is rejected.
  const ProgramShape one_param_select =
      ShapeUtil::MakeProgramShape({ShapeUtil::MakeShape(F32, {})}, pred_);
  const auto scatter_result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, one_param_select, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(scatter_result.ok());
  const auto& status = scatter_result.status();
  ASSERT_THAT(status.error_message(),
              HasSubstr("Select function must take 2 parameters"));
}
368 
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
  // A select computation returning F32 instead of PRED is rejected.
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  const ProgramShape non_pred_select =
      ShapeUtil::MakeProgramShape({f32_scalar, f32_scalar}, f32_);
  const auto scatter_result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, non_pred_select, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(scatter_result.ok());
  const auto& status = scatter_result.status();
  ASSERT_THAT(status.error_message(),
              HasSubstr("Select function must have rank-0 PRED"));
}
379 
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
  // The select computation's first parameter is S32 rather than F32.
  const ProgramShape bad_first_param_select = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
  const auto scatter_result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_first_param_select, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(scatter_result.ok());
  const auto& status = scatter_result.status();
  ASSERT_THAT(status.error_message(),
              HasSubstr("Select function's first parameter"));
}
390 
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
  // The select computation's second parameter is U32 rather than F32.
  const ProgramShape bad_second_param_select = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(U32, {})}, pred_);
  const auto scatter_result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_second_param_select, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(scatter_result.ok());
  const auto& status = scatter_result.status();
  ASSERT_THAT(status.error_message(),
              HasSubstr("Select function's second parameter"));
}
401 
TEST_F(ShapeInferenceTest, AllGatherStart) {
  // Gathering 8 shards of F32[1,8,4] along dimension 0 should yield the tuple
  // (operand, F32[8,8,4]).
  const Shape operand = ShapeUtil::MakeShape(F32, {1, 8, 4});
  const Shape expected_shape = ShapeUtil::MakeTupleShape(
      {operand, ShapeUtil::MakeShape(F32, {8, 8, 4})});

  auto inferred_ag_shape = ShapeInference::InferAllGatherStartShape(
      {&operand}, /*all_gather_dimension=*/0, /*shard_count=*/8);
  // ASSERT, not EXPECT: ValueOrDie() below would crash on an error status.
  ASSERT_IS_OK(inferred_ag_shape.status());
  EXPECT_TRUE(ShapeUtil::Equal(inferred_ag_shape.ValueOrDie(), expected_shape));
}
412 
TEST_F(ShapeInferenceTest, AllGatherDone) {
  // The input is an (F32[1,8,4], F32[8,8,4]) tuple; the done op is expected to
  // produce the tuple's second element, F32[8,8,4].
  const Shape input_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1, 8, 4}),
                                 ShapeUtil::MakeShape(F32, {8, 8, 4})});
  const Shape expected_shape = ShapeUtil::MakeShape(F32, {8, 8, 4});

  auto inferred_ag_done_shape =
      ShapeInference::InferAllGatherDoneShape(input_shape);
  // ASSERT, not EXPECT: ValueOrDie() below would crash on an error status.
  ASSERT_IS_OK(inferred_ag_done_shape.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(inferred_ag_done_shape.ValueOrDie(), expected_shape));
}
425 
TEST_F(ShapeInferenceTest, Convolve) {
  // lhs layout: [batch, feature, x0, x1].
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  // rhs layout: [x1, output feature, input feature, x0].
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  for (int64 spatial_dim : {2, 3}) {
    dnums.add_input_spatial_dimensions(spatial_dim);
    dnums.add_output_spatial_dimensions(spatial_dim);
  }
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);  // x0
  dnums.add_kernel_spatial_dimensions(0);  // x1

  Window window;
  auto add_window_dim = [&window](int64 size, int64 stride, int64 padding_low,
                                  int64 padding_high) {
    WindowDimension* window_dim = window.add_dimensions();
    window_dim->set_size(size);
    window_dim->set_stride(stride);
    window_dim->set_padding_low(padding_low);
    window_dim->set_padding_high(padding_high);
    window_dim->set_window_dilation(1);
    window_dim->set_base_dilation(1);
  };
  // x0: 3 + 1 + 1 padded elements, window 3, stride 2 -> 2 outputs.
  add_window_dim(3, 2, 1, 1);
  // x1: 4 elements, window 2, stride 1 -> 3 outputs.
  add_window_dim(2, 1, 0, 0);

  const auto conv_result = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(conv_result.status());
  const Shape output_shape = conv_result.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
                               output_shape));
}
470 
TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
  // lhs layout: [batch, feature, x0, x1].
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4});
  // rhs layout: [x1, output feature, input feature, x0].
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  for (int64 spatial_dim : {2, 3}) {
    dnums.add_input_spatial_dimensions(spatial_dim);
    dnums.add_output_spatial_dimensions(spatial_dim);
  }
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);  // x0
  dnums.add_kernel_spatial_dimensions(0);  // x1

  Window window;
  auto add_window_dim = [&window](int64 size, int64 stride, int64 padding_low,
                                  int64 padding_high, int64 window_dilation,
                                  int64 base_dilation) {
    WindowDimension* window_dim = window.add_dimensions();
    window_dim->set_size(size);
    window_dim->set_stride(stride);
    window_dim->set_padding_low(padding_low);
    window_dim->set_padding_high(padding_high);
    window_dim->set_window_dilation(window_dilation);
    window_dim->set_base_dilation(base_dilation);
  };
  // x0: size-3 window dilated by 6 spans 13 elements; (103 - 13) / 3 + 1 = 31.
  add_window_dim(3, 3, 0, 0, 6, 1);
  // x1: size-2 window dilated by 2 spans 3; 4 + 2 + 1 padded -> 5 outputs.
  add_window_dim(2, 1, 2, 1, 2, 1);

  const auto conv_result = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(conv_result.status());
  const Shape output_shape = conv_result.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
                               output_shape));
}
516 
TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
  // lhs layout: [batch, feature, x0, x1].
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  // rhs layout: [x1, output feature, input feature, x0].
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4});

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  for (int64 spatial_dim : {2, 3}) {
    dnums.add_input_spatial_dimensions(spatial_dim);
    dnums.add_output_spatial_dimensions(spatial_dim);
  }
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);  // x0
  dnums.add_kernel_spatial_dimensions(0);  // x1

  Window window;
  auto add_window_dim = [&window](int64 size, int64 stride, int64 padding_low,
                                  int64 padding_high, int64 window_dilation,
                                  int64 base_dilation) {
    WindowDimension* window_dim = window.add_dimensions();
    window_dim->set_size(size);
    window_dim->set_stride(stride);
    window_dim->set_padding_low(padding_low);
    window_dim->set_padding_high(padding_high);
    window_dim->set_window_dilation(window_dilation);
    window_dim->set_base_dilation(base_dilation);
  };
  // x0: 3 elements base-dilated by 6 span 13; (13 - 4) / 3 + 1 = 4 outputs.
  add_window_dim(4, 3, 0, 0, 1, 6);
  // x1: 4 elements base-dilated by 2 span 7, padded to 10 -> 9 outputs.
  add_window_dim(2, 1, 2, 1, 1, 2);

  const auto conv_result = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(conv_result.status());
  const Shape output_shape = conv_result.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
                               output_shape));
}
562 
TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
  // lhs/output layout per the dimension numbers below: [x0, x1, feature,
  // batch].
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2});

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(3);
  dnums.set_output_batch_dimension(3);
  dnums.set_input_feature_dimension(2);
  dnums.set_output_feature_dimension(2);
  dnums.add_input_spatial_dimensions(0);
  dnums.add_output_spatial_dimensions(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.add_output_spatial_dimensions(1);
  // Deliberate error: kernel dimension 0 is used both as the input feature
  // dimension and as kernel spatial dimension x0.
  dnums.set_kernel_input_feature_dimension(0);  // duplicated with kernel_x0
  dnums.set_kernel_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);

  // The dilations are set explicitly. Left unset, they would take the proto
  // default of 0, which is not a valid dilation. That way the overlapping
  // dimension numbers are the only invalid input in this test.
  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(2);
  dim0->set_stride(1);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(1);
  dim1->set_size(3);
  dim1->set_stride(2);
  dim1->set_padding_low(1);
  dim1->set_padding_high(1);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("each dimension exactly once"));
}
600 
TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) {
  // The lhs batch size (60) is divisible by batch_group_count (6), but the
  // kernel output feature dimension (10) is not. Inference must reject this.
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(2);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(3);
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {60, 38, 17, 13});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {38, 10, 4, 4});
  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(4);
  dim1->set_size(4);
  dim0->set_padding_low(0);
  dim0->set_padding_high(2);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim0->set_stride(1);
  dim1->set_stride(1);
  dim0->set_window_dilation(3);
  dim1->set_window_dilation(2);
  // Set explicitly. Left unset, it would take the proto default of 0, which
  // is not a valid dilation. That way the batch group count is the only
  // invalid input.
  dim0->set_base_dilation(1);
  dim1->set_base_dilation(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("to be a multiple of batch group count"));
}
637 
638 struct ConvolveArgs {
639   Shape lhs_shape;
640   Shape rhs_shape;
641   ConvolutionDimensionNumbers dnums;
642   Window window;
643 };
644 
MakeConvolveArgs(PrimitiveType lhs_type,PrimitiveType rhs_type)645 ConvolveArgs MakeConvolveArgs(PrimitiveType lhs_type, PrimitiveType rhs_type) {
646   ConvolveArgs args;
647   ConvolutionDimensionNumbers& dnums = args.dnums;
648 
649   // Dimension order: batch, feature, x0, x1
650   args.lhs_shape = ShapeUtil::MakeShape(lhs_type, {10, 11, 3, 4});
651   dnums.set_input_batch_dimension(0);
652   dnums.set_output_batch_dimension(0);
653   dnums.set_input_feature_dimension(1);
654   dnums.set_output_feature_dimension(1);
655   dnums.add_input_spatial_dimensions(2);
656   dnums.add_output_spatial_dimensions(2);
657   dnums.add_input_spatial_dimensions(3);
658   dnums.add_output_spatial_dimensions(3);
659 
660   // Dimension order: x1, batch, feature, x0
661   args.rhs_shape = ShapeUtil::MakeShape(rhs_type, {2, 12, 11, 3});
662   dnums.set_kernel_input_feature_dimension(2);
663   dnums.set_kernel_output_feature_dimension(1);
664   dnums.add_kernel_spatial_dimensions(3);
665   dnums.add_kernel_spatial_dimensions(0);
666 
667   auto dim0 = args.window.add_dimensions();
668   auto dim1 = args.window.add_dimensions();
669   dim0->set_size(3);
670   dim0->set_stride(2);
671   dim0->set_padding_low(1);
672   dim0->set_padding_high(1);
673   dim0->set_window_dilation(1);
674   dim0->set_base_dilation(1);
675   dim1->set_size(2);
676   dim1->set_stride(1);
677   dim1->set_padding_low(0);
678   dim1->set_padding_high(0);
679   dim1->set_window_dilation(1);
680   dim1->set_base_dilation(1);
681   return args;
682 }
683 
// Mixed BF16 x F16 operands with no preferred type: the result is BF16.
TEST_F(ShapeInferenceTest, ConvolveWithBF16_F16) {
  ConvolveArgs args = MakeConvolveArgs(BF16, F16);
  // Terminate the macro explicitly rather than relying on the trailing
  // semicolon baked into its definition.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}
695 
// Mixed F16 x BF16 operands with no preferred type: the result is BF16,
// regardless of operand order.
TEST_F(ShapeInferenceTest, ConvolveWithF16_BF16) {
  ConvolveArgs args = MakeConvolveArgs(F16, BF16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}
707 
// Mixed S32 x U32 operands with no preferred type: the result is S32.
TEST_F(ShapeInferenceTest, ConvolveWithS32_U32) {
  ConvolveArgs args = MakeConvolveArgs(S32, U32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}
719 
// Mixed U32 x S32 operands with no preferred type: the result is S32,
// regardless of operand order.
TEST_F(ShapeInferenceTest, ConvolveWithU32_S32) {
  ConvolveArgs args = MakeConvolveArgs(U32, S32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}
731 
// S8 x S16 operands with an explicit S16 preferred type yield an S16 result.
TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S16));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S16, {10, 12, 2, 3}),
                               inferred_shape));
}
743 
// S8 x S16 operands with an explicit S32 preferred type yield an S32 result.
// NOTE(review): the default inferred type for S8 x S16 is presumably S16, which
// is what ConvolveWithPreferredElementType requests. This test requests the
// wider S32, so the two test names may be swapped. Confirm before relying on
// the name.
TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementTypeSameAsInferredType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S32));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}
755 
// For floating-point operands, a narrower preferred type (F32 -> BF16) is
// accepted and becomes the result element type.
TEST_F(ShapeInferenceTest,
       FloatingPointConvolveWithNarrowerPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(F32, F32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/BF16));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}
768 
// A BF16 x BF16 convolution may request an integral (S32) result type.
TEST_F(ShapeInferenceTest,
       FloatingPointConvolveWithIntegralPreferredElementType) {
  const ConvolveArgs conv = MakeConvolveArgs(BF16, BF16);
  const Shape expected = ShapeUtil::MakeShape(S32, {10, 12, 2, 3});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferConvolveShape(
          conv.lhs_shape, conv.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, conv.window, conv.dnums,
          /*preferred_element_type=*/S32));
  ASSERT_TRUE(ShapeUtil::Equal(expected, inferred));
}
781 
// An S8 x S16 convolution may request a floating-point (F32) result type.
TEST_F(ShapeInferenceTest,
       IntegralConvolveWithFloatingPointPreferredElementType) {
  const ConvolveArgs conv = MakeConvolveArgs(S8, S16);
  const Shape expected = ShapeUtil::MakeShape(F32, {10, 12, 2, 3});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferConvolveShape(
          conv.lhs_shape, conv.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, conv.window, conv.dnums,
          /*preferred_element_type=*/F32));
  ASSERT_TRUE(ShapeUtil::Equal(expected, inferred));
}
794 
// Signed S8 x S16 operands may request an unsigned (U32) result type.
TEST_F(ShapeInferenceTest,
       ConvolveWithPreferredElementTypeWithDifferentSignedness) {
  const ConvolveArgs conv = MakeConvolveArgs(S8, S16);
  const Shape expected = ShapeUtil::MakeShape(U32, {10, 12, 2, 3});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferConvolveShape(
          conv.lhs_shape, conv.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, conv.window, conv.dnums,
          /*preferred_element_type=*/U32));
  ASSERT_TRUE(ShapeUtil::Equal(expected, inferred));
}
807 
// For integral S8 x S16 operands, requesting S8 is rejected as narrowing.
TEST_F(ShapeInferenceTest, ConvolveWithNarrowerPreferredElementType) {
  const ConvolveArgs conv = MakeConvolveArgs(S8, S16);
  const auto narrowing_result = ShapeInference::InferConvolveShape(
      conv.lhs_shape, conv.rhs_shape, /*feature_group_count=*/1,
      /*batch_group_count=*/1, conv.window, conv.dnums,
      /*preferred_element_type=*/S8);
  ASSERT_FALSE(narrowing_result.ok());
  ASSERT_THAT(narrowing_result.status().error_message(),
              HasSubstr("must not be narrower than the original type"));
}
820 
821 namespace fft {
822 
823 static const char* unsupported_rank = "only supports ranks 1-3";
824 static const char* invalid_rank = "requires input of at least same rank";
825 static const char* requires_complex_input = "requires complex input type";
826 static const char* requires_f32_input = "requires F32 or F64 input type";
827 static const char* dimensions_match = "innermost dimensions match fft_length";
828 static const char* innermost_dimension_matches =
829     "innermost dimension matches fft_length/2+1";
830 
Pass(const Shape & shape,FftType type,absl::Span<const int64> length,const Shape & expected_shape)831 static void Pass(const Shape& shape, FftType type,
832                  absl::Span<const int64> length, const Shape& expected_shape) {
833   auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
834   ASSERT_IS_OK(inferred_status.status());
835   Shape inferred_shape = inferred_status.ValueOrDie();
836   ASSERT_TRUE(ShapeUtil::Equal(inferred_shape, expected_shape));
837 }
838 
Fail(const Shape & shape,FftType type,absl::Span<const int64> length,absl::string_view message)839 static void Fail(const Shape& shape, FftType type,
840                  absl::Span<const int64> length, absl::string_view message) {
841   auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
842   ASSERT_FALSE(inferred_status.ok());
843   ASSERT_THAT(inferred_status.status().error_message(),
844               HasSubstr(std::string(message)));
845 }
846 
847 }  // namespace fft
848 
// FFT accepts fft_length of rank 1..operand rank and preserves the C64 shape.
TEST_F(ShapeInferenceTest, InferFftShapeTestFftRanks) {
  const FftType kind = FftType::FFT;
  const Shape c64_matrix = ShapeUtil::MakeShape(C64, {16, 8});
  fft::Fail(c64_matrix, kind, {}, fft::unsupported_rank);
  fft::Pass(c64_matrix, kind, {8}, c64_matrix);
  fft::Pass(c64_matrix, kind, {16, 8}, c64_matrix);
  fft::Fail(c64_matrix, kind, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(c64_matrix, kind, {64, 32, 16, 8}, fft::unsupported_rank);
}
858 
// FFT rejects a real (F32) operand and accepts C128, preserving its shape.
TEST_F(ShapeInferenceTest, InferFftShapeTestFftTypes) {
  const FftType kind = FftType::FFT;
  const Shape real_operand = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape c128_operand = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(real_operand, kind, {16, 8}, fft::requires_complex_input);
  fft::Pass(c128_operand, kind, {16, 8}, c128_operand);
}
866 
// IFFT accepts fft_length of rank 1..operand rank and preserves the C64 shape.
TEST_F(ShapeInferenceTest, InferFftShapeTestIfftRanks) {
  const FftType kind = FftType::IFFT;
  const Shape c64_matrix = ShapeUtil::MakeShape(C64, {16, 8});
  fft::Fail(c64_matrix, kind, {}, fft::unsupported_rank);
  fft::Pass(c64_matrix, kind, {8}, c64_matrix);
  fft::Pass(c64_matrix, kind, {16, 8}, c64_matrix);
  fft::Fail(c64_matrix, kind, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(c64_matrix, kind, {64, 32, 16, 8}, fft::unsupported_rank);
}
876 
// IFFT rejects a real (F32) operand and accepts C128, preserving its shape.
TEST_F(ShapeInferenceTest, InferFftShapeTestIfftTypes) {
  const FftType kind = FftType::IFFT;
  const Shape real_operand = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape c128_operand = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(real_operand, kind, {16, 8}, fft::requires_complex_input);
  fft::Pass(c128_operand, kind, {16, 8}, c128_operand);
}
884 
// RFFT of F32[16,8] yields C64[16,5] (innermost 8/2+1) for valid length ranks.
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftRanks) {
  const FftType kind = FftType::RFFT;
  const Shape real_input = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape complex_output = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Fail(real_input, kind, {}, fft::unsupported_rank);
  fft::Pass(real_input, kind, {8}, complex_output);
  fft::Pass(real_input, kind, {16, 8}, complex_output);
  fft::Fail(real_input, kind, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(real_input, kind, {64, 32, 16, 8}, fft::unsupported_rank);
}
895 
// RFFT requires fft_length to match the operand's innermost dimensions; it
// handles zero-sized dimensions and both even and odd innermost lengths.
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftDimensions) {
  const FftType kind = FftType::RFFT;

  const Shape mismatched = ShapeUtil::MakeShape(F32, {16, 8});
  fft::Fail(mismatched, kind, {4}, fft::dimensions_match);
  fft::Fail(mismatched, kind, {16, 4}, fft::dimensions_match);
  fft::Fail(mismatched, kind, {8, 8}, fft::dimensions_match);
  fft::Fail(mismatched, kind, {8, 16}, fft::dimensions_match);

  const Shape empty_input = ShapeUtil::MakeShape(F32, {16, 0});
  const Shape empty_output = ShapeUtil::MakeShape(C64, {16, 0});
  fft::Pass(empty_input, kind, {0}, empty_output);
  fft::Pass(empty_input, kind, {16, 0}, empty_output);

  // Both 8 and 9 map to an innermost output size of 5.
  const Shape even_input = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape odd_input = ShapeUtil::MakeShape(F32, {16, 9});
  const Shape complex_output = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Pass(even_input, kind, {16, 8}, complex_output);
  fft::Pass(odd_input, kind, {16, 9}, complex_output);
}
915 
// RFFT rejects complex operands (C64 and C128).
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftTypes) {
  const FftType kind = FftType::RFFT;
  const Shape c64_operand = ShapeUtil::MakeShape(C64, {16, 8});
  const Shape c128_operand = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(c64_operand, kind, {16, 8}, fft::requires_f32_input);
  fft::Fail(c128_operand, kind, {16, 8}, fft::requires_f32_input);
}
923 
// IRFFT of C64[16,5] yields F32[16,8] for valid fft_length ranks.
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftRanks) {
  const FftType kind = FftType::IRFFT;
  const Shape complex_input = ShapeUtil::MakeShape(C64, {16, 5});
  const Shape real_output = ShapeUtil::MakeShape(F32, {16, 8});
  fft::Fail(complex_input, kind, {}, fft::unsupported_rank);
  fft::Pass(complex_input, kind, {8}, real_output);
  fft::Pass(complex_input, kind, {16, 8}, real_output);
  fft::Fail(complex_input, kind, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(complex_input, kind, {64, 32, 16, 8}, fft::unsupported_rank);
}
934 
// IRFFT validates fft_length against the operand (innermost must equal
// fft_length/2+1, outer dimensions must match), handles zero-sized dimensions,
// and recovers both even and odd output lengths.
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftDimensions) {
  const FftType kind = FftType::IRFFT;

  const Shape complex_input = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Fail(complex_input, kind, {5}, fft::innermost_dimension_matches);
  fft::Fail(complex_input, kind, {16, 5}, fft::innermost_dimension_matches);
  fft::Fail(complex_input, kind, {8, 8}, fft::dimensions_match);
  fft::Fail(complex_input, kind, {8, 9}, fft::dimensions_match);

  const Shape empty_input = ShapeUtil::MakeShape(C64, {16, 0});
  const Shape empty_output = ShapeUtil::MakeShape(F32, {16, 0});
  fft::Pass(empty_input, kind, {0}, empty_output);
  fft::Pass(empty_input, kind, {16, 0}, empty_output);

  const Shape even_output = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape odd_output = ShapeUtil::MakeShape(F32, {16, 9});
  fft::Pass(complex_input, kind, {16, 8}, even_output);
  fft::Pass(complex_input, kind, {16, 9}, odd_output);
}
953 
// IRFFT rejects a real (F32) operand; a C128 operand produces an F64 result.
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftTypes) {
  const FftType kind = FftType::IRFFT;
  const Shape real_operand = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape c128_operand = ShapeUtil::MakeShape(C128, {16, 5});
  const Shape f64_result = ShapeUtil::MakeShape(F64, {16, 8});
  fft::Fail(real_operand, kind, {16, 8}, fft::requires_complex_input);
  fft::Pass(c128_operand, kind, {16, 8}, f64_result);
}
962 
// A map whose to_apply returns S32 over F32 elements produces an S32 array.
TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
  Shape arg = ShapeUtil::MakeShape(F32, {20});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_);
  auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
  // ASSERT (not EXPECT): ValueOrDie() on an error status aborts the whole test
  // binary instead of failing just this test.
  ASSERT_IS_OK(inferred_status.status());
  Shape expected = ShapeUtil::MakeShape(S32, {20});
  EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie()));
}
971 
TEST_F(ShapeInferenceTest,Map)972 TEST_F(ShapeInferenceTest, Map) {
973   auto inferred_status_r1f32 = ShapeInference::InferMapShape(
974       {&vector_32_, &vector_32_},
975       ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
976   EXPECT_IS_OK(inferred_status_r1f32.status());
977   EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie()));
978 
979   // It's OK to provide a single argument, as long as the applied arity matches
980   // (this degenerates to a Map).
981   auto inferred_status_r1f32_one = ShapeInference::InferMapShape(
982       {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0});
983   EXPECT_IS_OK(inferred_status_r1f32_one.status());
984   EXPECT_TRUE(
985       ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie()));
986 
987   auto inferred_status_r2s32 = ShapeInference::InferMapShape(
988       {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_},
989       ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1});
990   EXPECT_IS_OK(inferred_status_r2s32.status());
991   EXPECT_TRUE(
992       ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie()));
993 
994   auto no_args_error = ShapeInference::InferMapShape(
995       {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {});
996   ASSERT_FALSE(no_args_error.ok());
997   ASSERT_THAT(no_args_error.status().error_message(),
998               HasSubstr("expects at least one argument"));
999 
1000   auto args_diff_shapes_error = ShapeInference::InferMapShape(
1001       {&vector_32_, &vector_64_},
1002       ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
1003   ASSERT_FALSE(args_diff_shapes_error.ok());
1004   ASSERT_THAT(args_diff_shapes_error.status().error_message(),
1005               HasSubstr("requires all operands to have the same shape"));
1006 
1007   auto arity_error = ShapeInference::InferMapShape(
1008       {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_),
1009       {0});
1010   ASSERT_FALSE(arity_error.ok());
1011   ASSERT_THAT(arity_error.status().error_message(),
1012               HasSubstr("function arity must match"));
1013 
1014   auto output_shape_error = ShapeInference::InferMapShape(
1015       {&vector_32_, &vector_32_},
1016       ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0});
1017   ASSERT_FALSE(output_shape_error.ok());
1018   ASSERT_THAT(output_shape_error.status().error_message(),
1019               HasSubstr("result has to be a scalar"));
1020 
1021   auto param_shape_error = ShapeInference::InferMapShape(
1022       {&vector_32_, &vector_32_},
1023       ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0});
1024   ASSERT_FALSE(param_shape_error.ok());
1025   ASSERT_THAT(param_shape_error.status().error_message(),
1026               HasSubstr("parameter has to be a scalar"));
1027 
1028   auto param_element_type_error = ShapeInference::InferMapShape(
1029       {&vector_32_, &vector_32_},
1030       ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0});
1031   ASSERT_FALSE(param_element_type_error.ok());
1032   ASSERT_THAT(param_element_type_error.status().error_message(),
1033               HasSubstr("parameter type has to match argument"));
1034 
1035   Shape arg = ShapeUtil::MakeShape(F32, {20});
1036   ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
1037   auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
1038   EXPECT_IS_OK(inferred_status.status());
1039   EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie()));
1040 
1041   auto inferred_status_error1 = ShapeInference::InferMapShape(
1042       {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
1043   ASSERT_FALSE(inferred_status_error1.ok());
1044   ASSERT_THAT(inferred_status_error1.status().error_message(),
1045               HasSubstr("arity must match number of arguments"));
1046 
1047   auto inferred_status_error2 = ShapeInference::InferMapShape(
1048       {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0});
1049   ASSERT_FALSE(inferred_status_error2.ok());
1050   ASSERT_THAT(inferred_status_error2.status().error_message(),
1051               HasSubstr("has to be a scalar"));
1052 
1053   auto inferred_status_error3 = ShapeInference::InferMapShape(
1054       {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0});
1055   ASSERT_FALSE(inferred_status_error3.ok());
1056   ASSERT_THAT(inferred_status_error3.status().error_message(),
1057               HasSubstr("has to be a scalar"));
1058 
1059   auto inferred_status_error5 = ShapeInference::InferMapShape(
1060       {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0});
1061   ASSERT_FALSE(inferred_status_error5.ok());
1062   ASSERT_THAT(inferred_status_error5.status().error_message(),
1063               HasSubstr("parameter type has to match argument"));
1064 }
1065 
// Reducing the only dimension of an F32[128] operand yields an f32 scalar.
TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) {
  const Shape operand = ShapeUtil::MakeShape(F32, {128});
  ExpectInferredReduceShape(f32_, operand, /*dimensions_to_reduce=*/{0});
}
1070 
// Reducing dimension 0 of F32[2,3,4] leaves F32[3,4].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstDimension) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {3, 4});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{0});
}
1076 
// Reducing dimension 1 of F32[2,3,4] leaves F32[2,4].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongMiddleDimension) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {2, 4});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{1});
}
1082 
// Reducing dimensions {0, 1} of F32[2,3,4] leaves F32[4].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstTwoDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {4});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{0, 1});
}
1088 
// Reducing dimensions {1, 2} of F32[2,3,4] leaves F32[2].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongLastTwoDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {2});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{1, 2});
}
1094 
// Reducing dimensions {0, 2} of F32[2,3,4] leaves F32[3], in either order.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstAndLastDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {3});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{0, 2});
  // The order in which dimensions_to_reduce are listed must not matter.
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{2, 0});
}
1105 
// Reducing every dimension of F32[2,3,4] yields an f32 scalar.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(f32_, cube, /*dimensions_to_reduce=*/{0, 1, 2});
}
1110 
// A two-operand (F32, S32) reduce over all dimensions yields an (f32[], s32[])
// tuple.
TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  // ASSERT (not EXPECT): ValueOrDie() below would otherwise abort the whole
  // test binary on an error status.
  ASSERT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}),
                               inferred_status.ValueOrDie()));
}
1122 
// A two-operand reduce-window with VALID padding over [5,3,1] inputs and a
// {1,2,4} window yields a tuple of F32[5,2,0] and S32[5,2,0].
TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> args = {&f32_arg_shape, &s32_arg_shape};
  std::vector<const Shape*> inits = {&f32_, &s32_};
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  std::vector<int64> window_dimensions = {1, 2, 4};
  std::vector<int64> window_strides = {1, 1, 1};
  std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(AsInt64Slice(f32_arg_shape.dimensions()), window_dimensions,
                  window_strides, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(
      Window window,
      ShapeInference::InferWindowFromDimensions(
          window_dimensions, window_strides, padding_values, {}, {}));
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
  // Check the status before touching the value: ValueOrDie() on an error
  // aborts the whole test binary instead of failing this test.
  ASSERT_IS_OK(inferred_status.status());
  VLOG(2) << inferred_status.ValueOrDie().ToString() << "\n";
  EXPECT_TRUE(ShapeUtil::Equal(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {5, 2, 0}),
                                 ShapeUtil::MakeShape(S32, {5, 2, 0})}),
      inferred_status.ValueOrDie()));
}
1148 
// Two operands plus two inits need a 4-parameter reducer; 6 is rejected.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape six_param_reducer = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_, f32_, s32_},
      ShapeUtil::MakeTupleShape({f32_, s32_}));
  const auto result = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, six_param_reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("must take 4 parameters, but takes 6 parameter(s)"));
}
1161 
// The reducer's first parameter (s32) disagrees with its f32 result slot.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape mistyped_reducer = ShapeUtil::MakeProgramShape(
      {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  const auto result = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, mistyped_reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr(
          "parameter shape differs from the result shape: s32[] vs f32[]"));
}
1175 
// A reduce with no operands or init values at all is rejected.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) {
  const ProgramShape reducer = ShapeUtil::MakeProgramShape(
      {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  const auto result = ShapeInference::InferReduceShape({}, {0, 1}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("must have at least 2 arguments, has 0"));
}
1184 
// A reduce-window whose reducer takes only f32 parameters cannot consume the
// s32 operand/init; the error reports the f32[] vs s32[] mismatch.
TEST_F(ReduceShapeInferenceTest, ErrorBadReduceWindowInput) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3, 1});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> operands = {&f32_operand, &s32_operand};
  std::vector<const Shape*> init_values = {&f32_, &s32_};
  const ProgramShape all_f32_reducer = ShapeUtil::MakeProgramShape(
      {f32_, f32_, f32_, f32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));

  const std::vector<int64> window_sizes = {1, 2, 4};
  const std::vector<int64> unit_strides = {1, 1, 1};
  const std::vector<std::pair<int64, int64>> valid_padding =
      MakePadding(AsInt64Slice(f32_operand.dimensions()), window_sizes,
                  unit_strides, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(
      Window window,
      ShapeInference::InferWindowFromDimensions(
          window_sizes, unit_strides, valid_padding, {}, {}));

  const auto result = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(operands), absl::MakeSpan(init_values), window,
      all_f32_reducer);
  EXPECT_FALSE(result.status().ok());
  EXPECT_THAT(result.status().error_message(), HasSubstr("f32[] vs s32[]"));
}
1207 
// A two-operand reduce whose reducer returns a scalar instead of a 2-tuple.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape scalar_reducer =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_);
  const auto result = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, scalar_reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but produces a scalar"));
}
1220 
// A two-operand reduce whose reducer returns a 3-tuple instead of a 2-tuple.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape triple_reducer = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_}));
  const auto result = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, triple_reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but has 3 elements"));
}
1233 
// An all-s32 reducer whose accumulator at index 0 disagrees with the f32 init.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) {
  const Shape f32_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape all_s32_reducer = ShapeUtil::MakeProgramShape(
      {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_}));
  const auto result = ShapeInference::InferReduceShape(
      {&f32_operand, &s32_operand, &f32_, &s32_}, {0, 1}, all_s32_reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("accumulator shape at index 0 differs from the "
                        "init_value shape: s32[] vs f32[]"));
}
1246 
// Reducing dimensions 3 and 4 of a rank-2 operand is out of bounds.
TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
  const Shape matrix = ShapeUtil::MakeShape(F32, {5, 3});
  const ProgramShape add_like = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
  const auto result = ShapeInference::InferReduceShape(
      {&matrix, &f32_}, /*dimensions_to_reduce=*/{3, 4}, add_like);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("out-of-bounds dimension"));
}
1257 
// A single-operand reduce needs a 2-parameter reducer; 3 is rejected.
TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
  const Shape matrix = ShapeUtil::MakeShape(F32, {5, 3});
  const ProgramShape ternary =
      ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);
  const auto result = ShapeInference::InferReduceShape(
      {&matrix, &f32_}, /*dimensions_to_reduce=*/{0}, ternary);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(), HasSubstr("take 2 parameters"));
}
1268 
TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
  // The reducer takes f32 parameters but returns s32, which is inconsistent
  // with the f32 accumulator.
  Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  ProgramShape mistyped_reducer =
      ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);

  auto result = ShapeInference::InferReduceShape(
      {&operand, &f32_}, /*dimensions_to_reduce=*/{0}, mistyped_reducer);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("0-th parameter shape differs"));
}
1279 
TEST_F(ReduceShapeInferenceTest, ReduceWithRepeatedReduceDimension) {
  // Listing dimension 0 twice in dimensions_to_reduce is an error.
  Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  ProgramShape binary_reducer = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);

  auto result = ShapeInference::InferReduceShape(
      {&operand, &f32_}, /*dimensions_to_reduce=*/{0, 0}, binary_reducer);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Duplicate reduction dimension: 0"));
}
1290 
TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
  // Rows [32, 64) and columns [0, 64) with unit strides -> f32[32,64].
  Shape matrix = ShapeUtil::MakeShape(F32, {128, 64});
  Shape expected = ShapeUtil::MakeShape(F32, {32, 64});

  auto sliced = ShapeInference::InferSliceShape(matrix, /*starts=*/{32, 0},
                                                /*limits=*/{64, 64},
                                                /*strides=*/{1, 1});

  ASSERT_IS_OK(sliced.status());
  ASSERT_TRUE(ShapeUtil::Equal(expected, sliced.ValueOrDie()));
}
1299 
TEST_F(ShapeInferenceTest, InferSliceWithDynamicDimensions) {
  // Slicing a single row of a fully dynamic matrix: the size-1 result
  // dimension is static, while the untouched column dimension stays dynamic.
  Shape dynamic_matrix = ShapeUtil::MakeShape(F32, {128, 64}, {true, true});
  Shape expected = ShapeUtil::MakeShape(F32, {1, 64}, {false, true});

  auto sliced = ShapeInference::InferSliceShape(
      dynamic_matrix, /*starts=*/{32, 0}, /*limits=*/{33, 64},
      /*strides=*/{1, 1});

  ASSERT_IS_OK(sliced.status());
  ASSERT_TRUE(ShapeUtil::Equal(expected, sliced.ValueOrDie()));
}
1309 
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
  // 32 rows taken every 2nd -> 16; 64 columns taken every 4th -> 16.
  Shape matrix = ShapeUtil::MakeShape(F32, {128, 64});
  Shape expected = ShapeUtil::MakeShape(F32, {16, 16});

  auto sliced = ShapeInference::InferSliceShape(matrix, /*starts=*/{32, 0},
                                                /*limits=*/{64, 64},
                                                /*strides=*/{2, 4});

  ASSERT_IS_OK(sliced.status());
  ASSERT_TRUE(ShapeUtil::Equal(expected, sliced.ValueOrDie()));
}
1318 
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
  // Extents that are not multiples of the stride round up:
  // 5 rows / stride 2 -> 3, 13 columns / stride 4 -> 4.
  Shape matrix = ShapeUtil::MakeShape(F32, {128, 64});
  Shape expected = ShapeUtil::MakeShape(F32, {3, 4});

  auto sliced = ShapeInference::InferSliceShape(matrix, /*starts=*/{15, 0},
                                                /*limits=*/{20, 13},
                                                /*strides=*/{2, 4});

  ASSERT_IS_OK(sliced.status());
  ASSERT_TRUE(ShapeUtil::Equal(expected, sliced.ValueOrDie()));
}
1327 
TEST_F(ShapeInferenceTest, InferInvalidStride) {
  // A zero stride must be rejected with INVALID_ARGUMENT. The start/limit
  // indices are kept in bounds ([32, 64) x [0, 64) of a 128x64 matrix), so
  // the stride is the only invalid input. Otherwise an out-of-bounds limit
  // could trigger the error and hide a missing stride check.
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {0, 1});
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
            inferred_status.status().code());
}
1336 
TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
  // The row limit 129 exceeds the 128 available rows.
  Shape matrix = ShapeUtil::MakeShape(F32, {128, 64});

  auto sliced = ShapeInference::InferSliceShape(matrix, /*starts=*/{127, 0},
                                                /*limits=*/{129, 2},
                                                /*strides=*/{1, 1});

  ASSERT_FALSE(sliced.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, sliced.status().code());
}
1345 
TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
  // Elements [2, 4) of a length-17 vector -> f32[2].
  Shape vector = ShapeUtil::MakeShape(F32, {17});
  Shape expected = ShapeUtil::MakeShape(F32, {2});

  auto sliced = ShapeInference::InferSliceShape(vector, /*starts=*/{2},
                                                /*limits=*/{4},
                                                /*strides=*/{1});

  ASSERT_TRUE(sliced.ok());
  ASSERT_TRUE(ShapeUtil::Equal(sliced.ValueOrDie(), expected));
}
1354 
TEST_F(ShapeInferenceTest, InferConstIndexShape) {
  // get-tuple-element on (f32[], s32[]) returns the indexed element's shape.
  Shape pair = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto element0 = ShapeInference::InferGetTupleElementShape(pair, 0);
  auto element1 = ShapeInference::InferGetTupleElementShape(pair, 1);

  ASSERT_IS_OK(element0.status());
  ASSERT_IS_OK(element1.status());

  ASSERT_TRUE(ShapeUtil::Equal(f32_, element0.ValueOrDie()));
  ASSERT_TRUE(ShapeUtil::Equal(s32_, element1.ValueOrDie()));
}
1366 
TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) {
  // Indices -1 and 2 both fall outside a two-element tuple.
  Shape pair = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto below_range = ShapeInference::InferGetTupleElementShape(pair, -1);
  auto above_range = ShapeInference::InferGetTupleElementShape(pair, 2);

  ASSERT_FALSE(below_range.ok());
  ASSERT_FALSE(above_range.ok());

  EXPECT_THAT(below_range.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
  EXPECT_THAT(above_range.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
}
1380 
TEST_F(ShapeInferenceTest, InferPowShape) {
  // pow(f32[10], f32[]) broadcasts the scalar exponent -> f32[10].
  Shape vec10 = ShapeUtil::MakeShape(F32, {10});

  auto result = ShapeInference::InferBinaryOpShape(
      HloOpcode::kPower, vec10, f32_, /*broadcast_dimensions=*/{});

  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(vec10, result.ValueOrDie()));
}
1388 
TEST_F(ShapeInferenceTest, InferCompareShape) {
  // Comparing f32[10] against an f32 scalar yields a pred[10] mask.
  Shape vec10 = ShapeUtil::MakeShape(F32, {10});
  Shape expected = ShapeUtil::MakeShape(PRED, {10});

  auto result = ShapeInference::InferBinaryOpShape(
      HloOpcode::kCompare, vec10, f32_, /*broadcast_dimensions=*/{});

  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(expected, result.ValueOrDie()));
}
1397 
TEST_F(ShapeInferenceTest, InferReshapeDegenerateCombine) {
  // [1, <=1]
  //   | reshape
  // [<=1]
  //
  // Both operand dimensions collapse into the single output dimension. That
  // dimension is expected to be dynamic because one of its inputs is. There
  // is only one output dimension, so nothing needs tie-breaking and -1 is
  // passed for inferred_dimension.
  auto operand = ShapeUtil::MakeShape(F32, {1, 1}, {false, true});
  auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {1},
                                                  /*inferred_dimension=*/-1);
  // Check the status first: ValueOrDie() on an error status aborts the whole
  // test binary instead of reporting a test failure.
  ASSERT_IS_OK(status.status());
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {1}, {true}), status.ValueOrDie());
}
1409 
TEST_F(ShapeInferenceTest, InferReshapeSplit) {
  // [<=10]
  //   | reshape
  // [1, 10]
  //
  // Either output dimension could carry the dynamic bound, so
  // inferred_dimension breaks the tie: with inferred_dimension=0 the expected
  // result marks output dimension 0 as dynamic.
  auto operand = ShapeUtil::MakeShape(F32, {10}, {true});
  auto status = ShapeInference::InferReshapeShape(operand, {0}, {1, 10},
                                                  /*inferred_dimension=*/0);
  // Check the status first: ValueOrDie() on an error status aborts the whole
  // test binary instead of reporting a test failure.
  ASSERT_IS_OK(status.status());
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {1, 10}, {true, false}),
            status.ValueOrDie());
}
1422 
TEST_F(ShapeInferenceTest, InferReshapeCombine) {
  // [6, <=10]
  //   | reshape
  // [<=60]
  //
  // Merging a dynamic dimension into the single output dimension is expected
  // to make that output dimension dynamic. No split needs a tie-break, so
  // -1 is passed for inferred_dimension, matching the other non-split
  // reshape tests.
  auto operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
  auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {60},
                                                  /*inferred_dimension=*/-1);
  // Check the status first: ValueOrDie() on an error status aborts the whole
  // test binary instead of reporting a test failure.
  ASSERT_IS_OK(status.status());
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {60}, {true}), status.ValueOrDie());
}
1432 
TEST_F(ShapeInferenceTest, UnchangedDimension) {
  // [6, <=10]
  //   | reshape
  // [2, 3, <=10]
  //
  // The expected result keeps the size-10 dimension dynamic while the new
  // 2 and 3 dimensions are static. No dynamic dimension is split, so -1 is
  // passed for inferred_dimension.
  auto operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
  auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {2, 3, 10},
                                                  /*inferred_dimension=*/-1);
  // Check the status first: ValueOrDie() on an error status aborts the whole
  // test binary instead of reporting a test failure.
  ASSERT_IS_OK(status.status());
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {2, 3, 10}, {false, false, true}),
            status.ValueOrDie());
}
1443 
TEST_F(ShapeInferenceTest, InferDynamicBroadcast) {
  // Broadcasting a dynamic vector adds a new static major dimension and keeps
  // the operand's dynamic bound on the minor dimension:
  //   %broadcast = f32[15,<=15]{1,0} broadcast(f32[<=15]{0}), dimensions={1}
  auto operand_shape = ShapeUtil::MakeShape(F32, {15}, {true});
  auto inferred_status =
      ShapeInference::InferBroadcastShape(operand_shape, {15});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {15, 15}, {false, true}), inferred);
}
1455 
TEST_F(ShapeInferenceTest, BroadcastScalar) {
  for (const auto element_type : {F32, U32, S8}) {
    // Shapes of rank 0, 1 and 2 for the current element type.
    const Shape scalar = ShapeUtil::MakeShape(element_type, {});
    const Shape vec3 = ShapeUtil::MakeShape(element_type, {3});
    const Shape mat2x3 = ShapeUtil::MakeShape(element_type, {2, 3});

    // Broadcasting a scalar by no sizes leaves it unchanged.
    {
      auto result = ShapeInference::InferBroadcastShape(scalar, {});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(scalar, result.ValueOrDie()));
    }
    // Scalar broadcast to length 3.
    {
      auto result = ShapeInference::InferBroadcastShape(scalar, {3});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(vec3, result.ValueOrDie()));
    }
    // Broadcasting a vector by no sizes leaves it unchanged.
    {
      auto result = ShapeInference::InferBroadcastShape(vec3, {});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(vec3, result.ValueOrDie()));
    }
    // Scalar broadcast directly to 2x3.
    {
      auto result = ShapeInference::InferBroadcastShape(scalar, {2, 3});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(mat2x3, result.ValueOrDie()));
    }
    // Vector of 3 gains a new major dimension of size 2.
    {
      auto result = ShapeInference::InferBroadcastShape(vec3, {2});
      ASSERT_IS_OK(result.status());
      ASSERT_TRUE(ShapeUtil::Equal(mat2x3, result.ValueOrDie()));
    }
  }
}
1488 
1489 // scalar <dot> vector: ok
TEST_F(ShapeInferenceTest,ScalarDotVector)1490 TEST_F(ShapeInferenceTest, ScalarDotVector) {
1491   DotDimensionNumbers dot_dnums;
1492   auto inferred_status = ShapeInference::InferDotOpShape(
1493       f32_, vector_32_, dot_dnums, /*preferred_element_type=*/absl::nullopt);
1494   EXPECT_TRUE(inferred_status.ok());
1495   EXPECT_EQ(inferred_status.ValueOrDie(), vector_32_);
1496 }
1497 
1498 // 3D <dot> 2D: error
TEST_F(ShapeInferenceTest,DotWithRankHigherThanTwo)1499 TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
1500   DotDimensionNumbers dot_dnums;
1501   dot_dnums.add_lhs_contracting_dimensions(1);
1502   dot_dnums.add_rhs_contracting_dimensions(0);
1503   auto inferred_status = ShapeInference::InferDotOpShape(
1504       ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums,
1505       /*preferred_element_type=*/absl::nullopt);
1506   EXPECT_TRUE(inferred_status.ok());
1507   EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
1508                                ShapeUtil::MakeShape(F32, {32, 32, 64})));
1509 }
1510 
1511 // vector <dot> vector -> scalar
TEST_F(ShapeInferenceTest,VectorDotVector)1512 TEST_F(ShapeInferenceTest, VectorDotVector) {
1513   DotDimensionNumbers dot_dnums;
1514   dot_dnums.add_lhs_contracting_dimensions(0);
1515   dot_dnums.add_rhs_contracting_dimensions(0);
1516   auto inferred_status =
1517       ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums,
1518                                       /*preferred_element_type=*/absl::nullopt);
1519   ASSERT_IS_OK(inferred_status.status());
1520   ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
1521   auto inferred_status_mismatch =
1522       ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums,
1523                                       /*preferred_element_type=*/absl::nullopt);
1524   ASSERT_FALSE(inferred_status_mismatch.ok());
1525 }
1526 
1527 // matrix <dot> vector -> vector
TEST_F(ShapeInferenceTest,MatrixDotVector)1528 TEST_F(ShapeInferenceTest, MatrixDotVector) {
1529   DotDimensionNumbers dot_dnums;
1530   dot_dnums.add_lhs_contracting_dimensions(1);
1531   dot_dnums.add_rhs_contracting_dimensions(0);
1532   auto inferred_status =
1533       ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums,
1534                                       /*preferred_element_type=*/absl::nullopt);
1535   ASSERT_IS_OK(inferred_status.status());
1536   ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_));
1537   auto inferred_status_mismatch =
1538       ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums,
1539                                       /*preferred_element_type=*/absl::nullopt);
1540   ASSERT_FALSE(inferred_status_mismatch.ok());
1541 }
1542 
1543 // vector <dot> matrix -> vector
TEST_F(ShapeInferenceTest,VectorDotMatrix)1544 TEST_F(ShapeInferenceTest, VectorDotMatrix) {
1545   DotDimensionNumbers dot_dnums;
1546   dot_dnums.add_lhs_contracting_dimensions(0);
1547   dot_dnums.add_rhs_contracting_dimensions(0);
1548   auto inferred_status =
1549       ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums,
1550                                       /*preferred_element_type=*/absl::nullopt);
1551   ASSERT_IS_OK(inferred_status.status());
1552   ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_));
1553   auto inferred_status_mismatch =
1554       ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums,
1555                                       /*preferred_element_type=*/absl::nullopt);
1556   ASSERT_FALSE(inferred_status_mismatch.ok());
1557 }
1558 
1559 // matrix <dot> matrix -> matrix
TEST_F(ShapeInferenceTest,MatrixDotMatrix)1560 TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
1561   DotDimensionNumbers dot_dnums;
1562   dot_dnums.add_lhs_contracting_dimensions(1);
1563   dot_dnums.add_rhs_contracting_dimensions(0);
1564   auto inferred_status_match =
1565       ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums,
1566                                       /*preferred_element_type=*/absl::nullopt);
1567   ASSERT_IS_OK(inferred_status_match.status());
1568   ASSERT_TRUE(
1569       ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
1570       << "inferred: "
1571       << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
1572       << " expected: " << ShapeUtil::HumanString(matrix_64_48_);
1573   auto inferred_status_mismatch =
1574       ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums,
1575                                       /*preferred_element_type=*/absl::nullopt);
1576   ASSERT_FALSE(inferred_status_mismatch.ok());
1577 }
1578 
1579 // BatchMatMul with two batch dimensions and one contracting dimension.
TEST_F(ShapeInferenceTest,DotGeneral)1580 TEST_F(ShapeInferenceTest, DotGeneral) {
1581   Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3});
1582   Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14});
1583   Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14});
1584 
1585   DotDimensionNumbers dot_dnums;
1586   dot_dnums.add_lhs_contracting_dimensions(3);
1587   dot_dnums.add_lhs_batch_dimensions(0);
1588   dot_dnums.add_lhs_batch_dimensions(1);
1589 
1590   dot_dnums.add_rhs_contracting_dimensions(2);
1591   dot_dnums.add_rhs_batch_dimensions(0);
1592   dot_dnums.add_rhs_batch_dimensions(1);
1593 
1594   auto inferred_status_match =
1595       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1596                                       /*preferred_element_type=*/absl::nullopt);
1597   ASSERT_IS_OK(inferred_status_match.status());
1598   ASSERT_TRUE(
1599       ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape))
1600       << "inferred: "
1601       << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
1602       << " expected: " << ShapeUtil::HumanString(output_shape);
1603 }
1604 
1605 // BatchMatMul with two contracting dimensions fails.
TEST_F(ShapeInferenceTest,DotWithTwoContractingDimsFails)1606 TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) {
1607   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
1608   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
1609 
1610   DotDimensionNumbers dot_dnums;
1611   dot_dnums.add_lhs_contracting_dimensions(2);
1612   dot_dnums.add_lhs_contracting_dimensions(3);
1613   dot_dnums.add_lhs_batch_dimensions(0);
1614 
1615   dot_dnums.add_rhs_contracting_dimensions(1);
1616   dot_dnums.add_rhs_batch_dimensions(0);
1617 
1618   auto inferred_status =
1619       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1620                                       /*preferred_element_type=*/absl::nullopt);
1621   ASSERT_FALSE(inferred_status.ok());
1622   ASSERT_THAT(inferred_status.status().error_message(),
1623               HasSubstr("Must specify the same number of contracting "
1624                         "dimensions for lhs and rhs."));
1625 }
1626 
TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
  // Batch dimension 0 on both sides. lhs dims {2,3} (sizes 3,2) contract with
  // rhs dims {1,2} (sizes 3,2), leaving f32[2,11,14].
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 2, 14});
  Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_lhs_contracting_dimensions(3);
  dot_dnums.add_lhs_batch_dimensions(0);

  dot_dnums.add_rhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(2);
  dot_dnums.add_rhs_batch_dimensions(0);

  auto inferred_status =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  // Fatal check: ValueOrDie() below would abort the whole test binary if the
  // status were an error.
  ASSERT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape));
}
1647 
TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) {
  // The size value is s32[1] (rank 1) rather than an s32 scalar.
  Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  Shape size_value = ShapeUtil::MakeShape(S32, {1});

  auto result = ShapeInference::InferSetDimensionSizeShape(
      operand, size_value, /*dimension=*/0);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("value has to be S32 scalar"));
}
1658 
TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) {
  // The size value is a scalar, but u32 instead of s32.
  Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  Shape size_value = ShapeUtil::MakeShape(U32, {});

  auto result = ShapeInference::InferSetDimensionSizeShape(
      operand, size_value, /*dimension=*/0);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("value has to be S32 scalar"));
}
1669 
1670 // BatchMatMul with different batch dimension sizes fails.
TEST_F(ShapeInferenceTest,DotWithMismatchedBatchDimSizesFails)1671 TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
1672   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1673   Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14});
1674 
1675   DotDimensionNumbers dot_dnums;
1676   dot_dnums.add_lhs_contracting_dimensions(2);
1677   dot_dnums.add_lhs_batch_dimensions(0);
1678 
1679   dot_dnums.add_rhs_contracting_dimensions(1);
1680   dot_dnums.add_rhs_batch_dimensions(0);
1681 
1682   auto inferred_status =
1683       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1684                                       /*preferred_element_type=*/absl::nullopt);
1685   ASSERT_FALSE(inferred_status.ok());
1686   ASSERT_THAT(inferred_status.status().error_message(),
1687               HasSubstr("Batch dimension sizes must match"));
1688 }
1689 
1690 // BatchMatMul with different batch dimension numbers passes
TEST_F(ShapeInferenceTest,DotWithMismatchedBatchDimNumbersPasses)1691 TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersPasses) {
1692   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1693   Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14});
1694 
1695   DotDimensionNumbers dot_dnums;
1696   dot_dnums.add_lhs_contracting_dimensions(2);
1697   dot_dnums.add_lhs_batch_dimensions(0);
1698 
1699   dot_dnums.add_rhs_contracting_dimensions(0);
1700   dot_dnums.add_rhs_batch_dimensions(1);
1701 
1702   auto inferred_status =
1703       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1704                                       /*preferred_element_type=*/absl::nullopt);
1705   ASSERT_TRUE(inferred_status.ok());
1706   ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
1707                                ShapeUtil::MakeShape(F32, {2, 11, 14})));
1708 }
1709 
1710 // BatchMatMul with out-of-range dimension numbers fails.
TEST_F(ShapeInferenceTest,DotWithContractingDimNumberOutOfRange)1711 TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) {
1712   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1713   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
1714 
1715   DotDimensionNumbers dot_dnums;
1716   dot_dnums.add_lhs_contracting_dimensions(3);
1717   dot_dnums.add_lhs_batch_dimensions(0);
1718 
1719   dot_dnums.add_rhs_contracting_dimensions(0);
1720   dot_dnums.add_rhs_batch_dimensions(1);
1721 
1722   auto inferred_status =
1723       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1724                                       /*preferred_element_type=*/absl::nullopt);
1725   ASSERT_FALSE(inferred_status.ok());
1726   ASSERT_THAT(inferred_status.status().error_message(),
1727               HasSubstr("A dimension number is out of range"));
1728 }
1729 
1730 // BatchMatMul with non-unique dimension numbers fails.
TEST_F(ShapeInferenceTest,DotWithContractingNonUniqueDimNumber)1731 TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) {
1732   Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1733   Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
1734 
1735   DotDimensionNumbers dot_dnums;
1736   dot_dnums.add_lhs_contracting_dimensions(0);
1737   dot_dnums.add_lhs_batch_dimensions(0);
1738 
1739   dot_dnums.add_rhs_contracting_dimensions(0);
1740   dot_dnums.add_rhs_batch_dimensions(1);
1741 
1742   auto inferred_status =
1743       ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
1744                                       /*preferred_element_type=*/absl::nullopt);
1745   ASSERT_FALSE(inferred_status.ok());
1746   ASSERT_THAT(inferred_status.status().error_message(),
1747               HasSubstr("A dimension number is not unique"));
1748 }
1749 
TEST_F(ShapeInferenceTest, DotWithIntegralPreferredElementType) {
  // s8 x s16 matmul with a wider s32 preferred type produces s32.
  Shape lhs = ShapeUtil::MakeShape(S8, {32, 32});
  Shape rhs = ShapeUtil::MakeShape(S16, {32, 32});
  Shape expected = ShapeUtil::MakeShape(S32, {32, 32});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape result, ShapeInference::InferDotOpShape(
                        lhs, rhs, dnums, /*preferred_element_type=*/S32));
  EXPECT_TRUE(ShapeUtil::Equal(result, expected));
}
1762 
TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeSameAsInferredType) {
  // bf16 x f32 with preferred type f32 yields f32.
  Shape lhs = ShapeUtil::MakeShape(BF16, {32, 32});
  Shape rhs = ShapeUtil::MakeShape(F32, {32, 32});
  Shape expected = ShapeUtil::MakeShape(F32, {32, 32});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape result, ShapeInference::InferDotOpShape(
                        lhs, rhs, dnums, /*preferred_element_type=*/F32));
  EXPECT_TRUE(ShapeUtil::Equal(result, expected));
}
1775 
TEST_F(ShapeInferenceTest, FloatingPointDotWithNarrowerPreferredElementType) {
  // For floating point, a preferred type narrower than the operands (bf16
  // for a bf16 x f32 dot) is accepted and used as the result type.
  Shape lhs = ShapeUtil::MakeShape(BF16, {32, 32});
  Shape rhs = ShapeUtil::MakeShape(F32, {32, 32});
  Shape expected = ShapeUtil::MakeShape(BF16, {32, 32});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape result, ShapeInference::InferDotOpShape(
                        lhs, rhs, dnums, /*preferred_element_type=*/BF16));
  EXPECT_TRUE(ShapeUtil::Equal(result, expected));
}
1788 
TEST_F(ShapeInferenceTest, FloatingPointDotWithIntegralPreferredElementType) {
  // A bf16 x bf16 dot may request an integral (s32) result type.
  Shape operand = ShapeUtil::MakeShape(BF16, {32, 32});
  Shape expected = ShapeUtil::MakeShape(S32, {32, 32});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape result,
      ShapeInference::InferDotOpShape(operand, operand, dnums,
                                      /*preferred_element_type=*/S32));
  EXPECT_TRUE(ShapeUtil::Equal(result, expected));
}
1801 
TEST_F(ShapeInferenceTest, IntegralDotWithFloatingPointPreferredElementType) {
  // An s8 x s16 dot may request a floating-point (f32) result type.
  Shape lhs = ShapeUtil::MakeShape(S8, {32, 32});
  Shape rhs = ShapeUtil::MakeShape(S16, {32, 32});
  Shape expected = ShapeUtil::MakeShape(F32, {32, 32});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape result, ShapeInference::InferDotOpShape(
                        lhs, rhs, dnums, /*preferred_element_type=*/F32));
  EXPECT_TRUE(ShapeUtil::Equal(result, expected));
}
1814 
TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeWithDifferentSignedness) {
  // Signed s8 x s16 operands may request an unsigned u32 result type.
  Shape lhs = ShapeUtil::MakeShape(S8, {32, 32});
  Shape rhs = ShapeUtil::MakeShape(S16, {32, 32});
  Shape expected = ShapeUtil::MakeShape(U32, {32, 32});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape result, ShapeInference::InferDotOpShape(
                        lhs, rhs, dnums, /*preferred_element_type=*/U32));
  EXPECT_TRUE(ShapeUtil::Equal(result, expected));
}
1827 
TEST_F(ShapeInferenceTest, DotWithNarrowerPreferredElementType) {
  // For integral operands, a preferred type (s8) narrower than the inferred
  // one is rejected.
  Shape lhs = ShapeUtil::MakeShape(S8, {32, 32});
  Shape rhs = ShapeUtil::MakeShape(S16, {32, 32});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);

  auto status = ShapeInference::InferDotOpShape(
                    lhs, rhs, dnums, /*preferred_element_type=*/S8)
                    .status();
  ASSERT_FALSE(status.ok());
  ASSERT_THAT(status.error_message(),
              HasSubstr("must not be narrower than the original type"));
}
1841 
TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
  // Adding a vector to an f32[16,8] matrix is valid only when the broadcast
  // dimension names the matrix dimension whose size equals the vector length.
  const Shape mat = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
  const Shape vec16 = ShapeUtil::MakeShape(F32, {16});

  // vec8 lines up with dimension 1 (size 8) but not dimension 0 (size 16).
  auto vec8_on_dim1 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1});
  ASSERT_IS_OK(vec8_on_dim1.status());
  ASSERT_TRUE(ShapeUtil::Equal(vec8_on_dim1.ValueOrDie(), mat));

  auto vec8_on_dim0 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0});
  ASSERT_FALSE(vec8_on_dim0.ok());

  // vec16 lines up with dimension 0 (size 16) but not dimension 1 (size 8).
  auto vec16_on_dim0 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0});
  ASSERT_IS_OK(vec16_on_dim0.status());
  ASSERT_TRUE(ShapeUtil::Equal(vec16_on_dim0.ValueOrDie(), mat));

  auto vec16_on_dim1 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1});
  ASSERT_FALSE(vec16_on_dim1.ok());
}
1867 
TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) {
  // Each matrix is mapped onto the pair of cube dimensions with matching
  // sizes; every such add keeps the f32[16,8,4] cube shape.
  const Shape cube = ShapeUtil::MakeShape(F32, {16, 8, 4});
  const Shape minor_face = ShapeUtil::MakeShape(F32, {8, 4});
  const Shape outer_face = ShapeUtil::MakeShape(F32, {16, 4});
  const Shape major_face = ShapeUtil::MakeShape(F32, {16, 8});

  auto onto_dims_1_2 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, cube, minor_face, {1, 2});
  ASSERT_IS_OK(onto_dims_1_2.status());
  ASSERT_TRUE(ShapeUtil::Equal(onto_dims_1_2.ValueOrDie(), cube));

  auto onto_dims_0_2 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, cube, outer_face, {0, 2});
  ASSERT_IS_OK(onto_dims_0_2.status());
  ASSERT_TRUE(ShapeUtil::Equal(onto_dims_0_2.ValueOrDie(), cube));

  auto onto_dims_0_1 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, cube, major_face, {0, 1});
  ASSERT_IS_OK(onto_dims_0_1.status());
  ASSERT_TRUE(ShapeUtil::Equal(onto_dims_0_1.ValueOrDie(), cube));
}
1890 
TEST_F(ShapeInferenceTest,BinOpBroadcastBadDimension)1891 TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
1892   // Test various errors with the broadcast argument.
1893   const Shape tensor = ShapeUtil::MakeShape(F32, {16, 8, 4});
1894   const Shape tensor8_8_8 = ShapeUtil::MakeShape(F32, {8, 8, 8});
1895   const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
1896   const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4});
1897   const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8});
1898 
1899   // "magical" broadcast rejected
1900   auto inferred_status_error1 =
1901       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {});
1902   ASSERT_FALSE(inferred_status_error1.ok());
1903   ASSERT_THAT(inferred_status_error1.status().error_message(),
1904               HasSubstr("Automatic"));
1905 
1906   // broadcast_dimension out of bounds for tensor's rank
1907   auto inferred_status_error2 =
1908       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3});
1909   ASSERT_FALSE(inferred_status_error2.ok());
1910   ASSERT_THAT(inferred_status_error2.status().error_message(),
1911               ContainsRegex("Broadcast dimension number .* too large"));
1912 
1913   // broadcast_dimension doesn't match corresponding dimension
1914   auto inferred_status_error3 =
1915       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0});
1916   ASSERT_FALSE(inferred_status_error3.ok());
1917   ASSERT_THAT(inferred_status_error3.status().error_message(),
1918               HasSubstr("Broadcast dimension 0 mismatch"));
1919 
1920   // broadcast_dimensions list too long
1921   auto inferred_status_error4 = ShapeInference::InferBinaryOpShape(
1922       HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2});
1923   ASSERT_FALSE(inferred_status_error4.ok());
1924   ASSERT_THAT(inferred_status_error4.status().error_message(),
1925               HasSubstr("broadcast_dimensions has to match"));
1926 
1927   // there's a dimension above the rank of the tensor
1928   auto inferred_status_error5 = ShapeInference::InferBinaryOpShape(
1929       HloOpcode::kAdd, tensor, matrix8_4, {3, 0});
1930   ASSERT_FALSE(inferred_status_error5.ok());
1931   ASSERT_THAT(inferred_status_error5.status().error_message(),
1932               ContainsRegex("dimension number .* too large"));
1933 
1934   // broadcasting dimensions don't match in this order
1935   auto inferred_status_error6 = ShapeInference::InferBinaryOpShape(
1936       HloOpcode::kAdd, tensor, matrix8_4, {2, 1});
1937   ASSERT_FALSE(inferred_status_error6.ok());
1938   ASSERT_THAT(inferred_status_error6.status().error_message(),
1939               HasSubstr("dimension 0 mismatch"));
1940 
1941   // The following two tests make sure that broadcasting dimensions are listed
1942   // in a proper (strictly increasing) order, even if the lower-rank array
1943   // matches the higher-rank array in many different ways.
1944   auto inferred_status_error7 = ShapeInference::InferBinaryOpShape(
1945       HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0});
1946   ASSERT_FALSE(inferred_status_error7.ok());
1947   ASSERT_THAT(inferred_status_error7.status().error_message(),
1948               HasSubstr("dimensions order is wrong"));
1949 
1950   auto inferred_status_error8 = ShapeInference::InferBinaryOpShape(
1951       HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0});
1952   ASSERT_FALSE(inferred_status_error8.ok());
1953   ASSERT_THAT(inferred_status_error8.status().error_message(),
1954               HasSubstr("dimensions order is wrong"));
1955 }
1956 
1957 // Tests for the while instruction with proper shapes.
TEST_F(ShapeInferenceTest,WhileWithCorrectShapes)1958 TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) {
1959   Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
1960   ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
1961   ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);
1962   auto inferred_status =
1963       ShapeInference::InferWhileShape(cond, body, result_shape);
1964   ASSERT_IS_OK(inferred_status.status());
1965   Shape inferred = inferred_status.ValueOrDie();
1966   ASSERT_TRUE(ShapeUtil::Equal(result_shape, inferred));
1967 }
1968 
1969 // Tests for the while instruction with wrong shapes.
TEST_F(ShapeInferenceTest,WhileWithBadShapes)1970 TEST_F(ShapeInferenceTest, WhileWithBadShapes) {
1971   Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
1972   ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
1973   ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);
1974 
1975   auto bad_shape_1 = ShapeUtil::MakeProgramShape({s32_, result_shape}, pred_);
1976   auto inferred_status_error1 =
1977       ShapeInference::InferWhileShape(bad_shape_1, body, result_shape);
1978   ASSERT_FALSE(inferred_status_error1.ok());
1979   ASSERT_THAT(inferred_status_error1.status().error_message(),
1980               HasSubstr("Condition must take 1 arguments"));
1981 
1982   auto bad_shape_2 =
1983       ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape);
1984   auto inferred_status_error2 =
1985       ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape);
1986   ASSERT_FALSE(inferred_status_error2.ok());
1987   ASSERT_THAT(inferred_status_error2.status().error_message(),
1988               HasSubstr("Body must take 1 arguments"));
1989 
1990   auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_);
1991   auto inferred_status_error3 =
1992       ShapeInference::InferWhileShape(bad_shape_3, body, result_shape);
1993   ASSERT_FALSE(inferred_status_error3.ok());
1994   ASSERT_THAT(inferred_status_error3.status().error_message(),
1995               HasSubstr("Condition must return a boolean"));
1996 
1997   auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_);
1998   auto inferred_status_error4 =
1999       ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape);
2000   ASSERT_FALSE(inferred_status_error4.ok());
2001   ASSERT_THAT(inferred_status_error4.status().error_message(),
2002               HasSubstr("parameter of condition and body"));
2003 }
2004 
2005 // Tests for the concatenate instruction with dynamic shapes.
TEST_F(ShapeInferenceTest,ConcatenateWithDynamicShapes)2006 TEST_F(ShapeInferenceTest, ConcatenateWithDynamicShapes) {
2007   auto dynamic_shape_1 =
2008       ShapeUtil::MakeShape(F32, {32, 160, 10}, {true, false, false});
2009   auto dynamic_shape_2 =
2010       ShapeUtil::MakeShape(F32, {32, 160, 10}, {false, true, false});
2011   auto inferred_status = ShapeInference::InferConcatOpShape(
2012       {&dynamic_shape_1, &dynamic_shape_2}, /*dimension=*/0);
2013   ASSERT_IS_OK(inferred_status.status());
2014   Shape inferred = inferred_status.ValueOrDie();
2015   ASSERT_TRUE(ShapeUtil::Equal(
2016       ShapeUtil::MakeShape(F32, {64, 160, 10}, {true, true, false}), inferred));
2017 }
2018 
2019 // Tests for the concatenate instruction with proper shapes.
TEST_F(ShapeInferenceTest,ConcatenateWithCorrectShapes)2020 TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) {
2021   auto inferred_status_1 = ShapeInference::InferConcatOpShape(
2022       {&vector_32_, &vector_64_}, /*dimension=*/0);
2023   ASSERT_IS_OK(inferred_status_1.status());
2024   Shape inferred_1 = inferred_status_1.ValueOrDie();
2025   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), inferred_1));
2026 
2027   auto inferred_status_2 = ShapeInference::InferConcatOpShape(
2028       {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0);
2029   ASSERT_IS_OK(inferred_status_2.status());
2030   Shape inferred_2 = inferred_status_2.ValueOrDie();
2031   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), inferred_2));
2032 
2033   auto inferred_status_3 = ShapeInference::InferConcatOpShape(
2034       {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1);
2035   ASSERT_IS_OK(inferred_status_3.status());
2036   Shape inferred_3 = inferred_status_3.ValueOrDie();
2037   ASSERT_TRUE(
2038       ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}), inferred_3));
2039 }
2040 
2041 // Tests for the concatenate instruction with wrong shapes.
TEST_F(ShapeInferenceTest,ConcatenateWithBadShapes)2042 TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
2043   auto inferred_status_error1 =
2044       ShapeInference::InferConcatOpShape({}, /*dimension=*/0);
2045   ASSERT_FALSE(inferred_status_error1.ok());
2046   ASSERT_THAT(inferred_status_error1.status().error_message(),
2047               HasSubstr("Concatenate expects at least one argument"));
2048 
2049   auto inferred_status_error2 =
2050       ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1);
2051   ASSERT_FALSE(inferred_status_error2.ok());
2052   ASSERT_THAT(inferred_status_error2.status().error_message(),
2053               HasSubstr("dimension out of bounds: -1"));
2054 
2055   auto inferred_status_error3 =
2056       ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1);
2057   ASSERT_FALSE(inferred_status_error3.ok());
2058   ASSERT_THAT(inferred_status_error3.status().error_message(),
2059               HasSubstr("dimension out of bounds: 1"));
2060 
2061   Shape tuple = ShapeUtil::MakeTupleShape({vector_32_});
2062   auto inferred_status_error4 = ShapeInference::InferConcatOpShape(
2063       {&vector_32_, &tuple}, /*dimension=*/0);
2064   ASSERT_FALSE(inferred_status_error4.ok());
2065   ASSERT_THAT(
2066       inferred_status_error4.status().error_message(),
2067       HasSubstr("Expected array argument for operand of concatenation"));
2068 
2069   const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
2070   auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
2071       {&vector_32_, &vector_s32}, /*dimension=*/0);
2072   ASSERT_FALSE(inferred_status_error5.ok());
2073   ASSERT_THAT(inferred_status_error5.status().error_message(),
2074               HasSubstr("concatenate arrays with different element types"));
2075 
2076   auto inferred_status_error6 = ShapeInference::InferConcatOpShape(
2077       {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0);
2078   ASSERT_FALSE(inferred_status_error6.ok());
2079   ASSERT_THAT(inferred_status_error6.status().error_message(),
2080               HasSubstr("concatenate arrays that differ in "
2081                         "dimensions other than the one being "
2082                         "concatenated"));
2083 }
2084 
TEST_F(ShapeInferenceTest, Pad) {
  // Pads a [10, 25] F32 array with an F32 scalar. The expected sizes follow
  // low + high + size + interior * (size - 1):
  //   rows:    0 + 2 + 10 + 3 * 9 = 39
  //   columns: 1 + 5 + 25 + 0     = 31
  const Shape operand = ShapeUtil::MakeShape(F32, {10, 25});
  const Shape pad_value = ShapeUtil::MakeShape(F32, {});

  PaddingConfig config;
  auto* rows = config.add_dimensions();
  rows->set_edge_padding_low(0);
  rows->set_edge_padding_high(2);
  rows->set_interior_padding(3);
  auto* cols = config.add_dimensions();
  cols->set_edge_padding_low(1);
  cols->set_edge_padding_high(5);
  cols->set_interior_padding(0);

  StatusOr<Shape> padded =
      ShapeInference::InferPadShape(operand, pad_value, config);
  ASSERT_IS_OK(padded.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}),
                               padded.ValueOrDie()));

  // Negative edge padding removes elements; removing 20 + 10 of 25 columns
  // would leave -5, which must be rejected.
  cols->set_edge_padding_low(-20);
  cols->set_edge_padding_high(-10);
  StatusOr<Shape> over_trimmed =
      ShapeInference::InferPadShape(operand, pad_value, config);
  ASSERT_FALSE(over_trimmed.ok());
  ASSERT_THAT(over_trimmed.status().error_message(),
              HasSubstr("negative size for dimension 1"));
}
2115 
TEST_F(ShapeInferenceTest, Reverse) {
  // Reversing along every dimension leaves the shape unchanged.
  const Shape operand = ShapeUtil::MakeShape(F32, {10, 25});

  StatusOr<Shape> reversed =
      ShapeInference::InferReverseShape(operand, {0, 1});
  ASSERT_IS_OK(reversed.status());
  ASSERT_TRUE(ShapeUtil::Equal(operand, reversed.ValueOrDie()));
}
2124 
TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
  // Reverse dimensions must lie in [0, rank), must not repeat, and the
  // operand must be an array.
  const Shape operand = ShapeUtil::MakeShape(F32, {10, 25});

  StatusOr<Shape> past_rank =
      ShapeInference::InferReverseShape(operand, {0, 2});
  ASSERT_FALSE(past_rank.ok());
  ASSERT_THAT(past_rank.status().error_message(), HasSubstr("out-of-bounds"));

  StatusOr<Shape> negative =
      ShapeInference::InferReverseShape(operand, {0, -1});
  ASSERT_FALSE(negative.ok());
  ASSERT_THAT(negative.status().error_message(), HasSubstr("out-of-bounds"));

  StatusOr<Shape> repeated =
      ShapeInference::InferReverseShape(operand, {0, 0});
  ASSERT_FALSE(repeated.ok());
  ASSERT_THAT(repeated.status().error_message(), HasSubstr("duplicated"));

  Shape pair = ShapeUtil::MakeTupleShape({operand, operand});
  StatusOr<Shape> tuple_operand = ShapeInference::InferReverseShape(pair, {0});
  ASSERT_FALSE(tuple_operand.ok());
  ASSERT_THAT(tuple_operand.status().error_message(),
              HasSubstr("Expected array argument"));
}
2153 
TEST_F(ShapeInferenceTest, Call) {
  // A call infers the callee's result shape, provided the argument shapes
  // match the callee's parameters one-for-one.
  //
  // The success cases use ASSERT_IS_OK, not EXPECT_IS_OK. The status is
  // unwrapped with ValueOrDie() right afterwards, and on an error status that
  // call would terminate the whole test binary instead of failing this test.
  auto inferred_status0 =
      ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({}, f32_));
  ASSERT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));

  auto inferred_status1 = ShapeInference::InferCallShape(
      {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_},
      ShapeUtil::MakeProgramShape(
          {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_));
  ASSERT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(s32matrix_64_64_, inferred_status1.ValueOrDie()));

  // Too few arguments for the callee's parameters.
  auto inferred_status_error0 = ShapeInference::InferCallShape(
      {}, ShapeUtil::MakeProgramShape({f32_}, f32_));
  EXPECT_FALSE(inferred_status_error0.ok());
  EXPECT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("arity must match"));

  // Too many arguments for the callee's parameters.
  auto inferred_status_error1 = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({}, f32_));
  EXPECT_FALSE(inferred_status_error1.ok());
  EXPECT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("arity must match"));

  // Arity matches, but the F32 argument does not fit the S32 parameter.
  auto inferred_status_error2 = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_));
  EXPECT_FALSE(inferred_status_error2.ok());
  EXPECT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("parameter must match argument"));
}
2186 
TEST_F(ShapeInferenceTest, Transpose) {
  // With permutation {1, 2, 3, 0}, output dimension i takes the size of input
  // dimension permutation[i], so [2, 3, 4, 5] becomes [3, 4, 5, 2].
  Shape a_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
  auto inferred_shape_and_status =
      ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0});
  // Fatal check: ValueOrDie() on an error status would terminate the process
  // rather than fail this test.
  ASSERT_IS_OK(inferred_shape_and_status.status());
  Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
  // Compatible (not Equal): only element type and dimensions are compared,
  // not layout.
  EXPECT_TRUE(ShapeUtil::Compatible(inferred_shape,
                                    ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
}
2196 
TEST_F(ShapeInferenceTest, Rank1Transpose) {
  // The identity permutation on a rank-1 array leaves its shape unchanged.
  Shape a_shape = ShapeUtil::MakeShape(F32, {5});
  auto inferred_shape_and_status =
      ShapeInference::InferTransposeShape(a_shape, {0});
  // Fatal check: ValueOrDie() on an error status would terminate the process
  // rather than fail this test.
  ASSERT_IS_OK(inferred_shape_and_status.status());
  Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
  EXPECT_TRUE(
      ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {5})));
}
2206 
TEST_F(ShapeInferenceTest,ConditionalPred)2207 TEST_F(ShapeInferenceTest, ConditionalPred) {
2208   auto inferred_status0 = ShapeInference::InferConditionalShape(
2209       pred_,
2210       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2211        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
2212       {vector_32_, vector_64_});
2213   EXPECT_IS_OK(inferred_status0.status());
2214   EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));
2215 
2216   auto inferred_status1 = ShapeInference::InferConditionalShape(
2217       pred_,
2218       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
2219        ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)},
2220       {matrix_32_48_, vector_32_});
2221   EXPECT_IS_OK(inferred_status1.status());
2222   EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));
2223 
2224   auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
2225   auto inferred_status2 = ShapeInference::InferConditionalShape(
2226       pred_,
2227       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
2228        ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
2229       {matrix_32_48_, tuple_f32_v32});
2230   EXPECT_IS_OK(inferred_status2.status());
2231   EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));
2232 
2233   auto inferred_status_error0 = ShapeInference::InferConditionalShape(
2234       f32_,
2235       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2236        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
2237       {vector_32_, vector_64_});
2238   EXPECT_FALSE(inferred_status_error0.ok());
2239   EXPECT_THAT(inferred_status_error0.status().error_message(),
2240               HasSubstr("must be bool or int32"));
2241 
2242   auto inferred_status_error1 = ShapeInference::InferConditionalShape(
2243       pred_,
2244       {ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
2245        ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
2246       {ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_});
2247   EXPECT_FALSE(inferred_status_error1.ok());
2248   EXPECT_THAT(inferred_status_error1.status().error_message(),
2249               HasSubstr("branch computation 0 must take 1 argument"));
2250 
2251   auto inferred_status_error2 = ShapeInference::InferConditionalShape(
2252       pred_,
2253       {ShapeUtil::MakeProgramShape({vector_64_}, f32_),
2254        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
2255       {vector_32_, vector_64_});
2256   EXPECT_FALSE(inferred_status_error2.ok());
2257   EXPECT_THAT(inferred_status_error2.status().error_message(),
2258               HasSubstr("branch operand 0 must match the shape of the only "
2259                         "parameter of branch computation 0"));
2260 
2261   auto inferred_status_error3 = ShapeInference::InferConditionalShape(
2262       pred_,
2263       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
2264        ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)},
2265       {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})});
2266   EXPECT_FALSE(inferred_status_error3.ok());
2267   EXPECT_THAT(inferred_status_error3.status().error_message(),
2268               HasSubstr("branch computation 1 must take 1 argument"));
2269 
2270   auto inferred_status_error4 = ShapeInference::InferConditionalShape(
2271       pred_,
2272       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2273        ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
2274       {vector_32_, vector_64_});
2275   EXPECT_FALSE(inferred_status_error4.ok());
2276   EXPECT_THAT(inferred_status_error4.status().error_message(),
2277               HasSubstr("branch operand 1 must match the shape of the only "
2278                         "parameter of branch computation 1"));
2279 
2280   auto inferred_status_error5 = ShapeInference::InferConditionalShape(
2281       pred_,
2282       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2283        ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
2284       {vector_32_, vector_64_});
2285   EXPECT_FALSE(inferred_status_error5.ok());
2286   EXPECT_THAT(inferred_status_error5.status().error_message(),
2287               HasSubstr("the result of branch 0 computation and branch 1 "
2288                         "computation must have the same shape"));
2289 }
2290 
TEST_F(ShapeInferenceTest,ConditionalIndexed)2291 TEST_F(ShapeInferenceTest, ConditionalIndexed) {
2292   auto r0s32 = ShapeUtil::MakeShape(S32, {});
2293   auto inferred_status0 = ShapeInference::InferConditionalShape(
2294       r0s32,
2295       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2296        ShapeUtil::MakeProgramShape({vector_64_}, f32_),
2297        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
2298       {vector_32_, vector_64_, vector_64_});
2299   EXPECT_IS_OK(inferred_status0.status());
2300   EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));
2301 
2302   auto inferred_status1 = ShapeInference::InferConditionalShape(
2303       r0s32,
2304       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
2305        ShapeUtil::MakeProgramShape({vector_32_}, vector_64_),
2306        ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)},
2307       {matrix_32_48_, vector_32_, matrix_32_48_});
2308   EXPECT_IS_OK(inferred_status1.status());
2309   EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));
2310 
2311   auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
2312   auto inferred_status2 = ShapeInference::InferConditionalShape(
2313       r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
2314       {tuple_f32_v32});
2315   EXPECT_IS_OK(inferred_status2.status());
2316   EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));
2317 
2318   auto inferred_status_error0 = ShapeInference::InferConditionalShape(
2319       pred_,
2320       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2321        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2322        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
2323       {vector_32_, vector_32_, vector_64_});
2324   EXPECT_FALSE(inferred_status_error0.ok());
2325   EXPECT_THAT(inferred_status_error0.status().error_message(),
2326               HasSubstr("2 == branch_computations.size()"));
2327 
2328   auto inferred_status_error1 = ShapeInference::InferConditionalShape(
2329       r0s32,
2330       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
2331        ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
2332        ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
2333       {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}),
2334        matrix_32_48_});
2335   EXPECT_FALSE(inferred_status_error1.ok());
2336   EXPECT_THAT(inferred_status_error1.status().error_message(),
2337               HasSubstr("branch computation 1 must take 1 argument"));
2338 
2339   auto inferred_status_error2 = ShapeInference::InferConditionalShape(
2340       r0s32,
2341       {ShapeUtil::MakeProgramShape({r0s32}, f32_),
2342        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2343        ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
2344       {r0s32, vector_32_, vector_64_});
2345   EXPECT_FALSE(inferred_status_error2.ok());
2346   EXPECT_THAT(inferred_status_error2.status().error_message(),
2347               HasSubstr("branch operand 2 must match the shape of the only "
2348                         "parameter of branch computation 2"));
2349 
2350   auto inferred_status_error3 = ShapeInference::InferConditionalShape(
2351       r0s32,
2352       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2353        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2354        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
2355        ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
2356       {vector_32_, vector_32_, vector_32_, vector_64_});
2357   EXPECT_FALSE(inferred_status_error3.ok());
2358   EXPECT_THAT(inferred_status_error3.status().error_message(),
2359               HasSubstr("the result of branch 0 computation and branch 3 "
2360                         "computation must have the same shape"));
2361 
2362   auto inferred_status_error4 =
2363       ShapeInference::InferConditionalShape(r0s32, {}, {});
2364   EXPECT_FALSE(inferred_status_error4.ok());
2365   EXPECT_THAT(inferred_status_error4.status().error_message(),
2366               HasSubstr("!branch_computations.empty()"));
2367 }
2368 
TEST_F(ShapeInferenceTest, ConditionalDynamic) {
  // Branch results that agree in bounds but differ in dynamism are accepted.
  // The inferred shape is the dynamic one, whichever branch (first or
  // second) returns the static shape.
  //
  // ASSERT_IS_OK (not EXPECT_IS_OK) guards the ValueOrDie() calls, which
  // would terminate the test binary on an error status.
  auto r0s32 = ShapeUtil::MakeShape(S32, {});
  auto static_shape = ShapeUtil::MakeShape(S32, {4}, {false});
  auto dynamic_shape = ShapeUtil::MakeShape(S32, {4}, {true});
  auto inferred_status0 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, static_shape),
       ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape),
       ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)},
      {vector_32_, vector_64_, vector_64_});
  ASSERT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, inferred_status0.ValueOrDie()));

  auto inferred_status1 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, dynamic_shape),
       ShapeUtil::MakeProgramShape({vector_64_}, static_shape),
       ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)},
      {vector_32_, vector_64_, vector_64_});
  ASSERT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, inferred_status1.ValueOrDie()));
}
2391 
TEST_F(ShapeInferenceTest, BadSlice) {
  // A slice of a length-4 array with limit 5 overruns the dimension. The
  // error must explain the bound and mention the argument shape.
  const Shape operand = ShapeUtil::MakeShape(F32, {4});
  StatusOr<Shape> statusor =
      ShapeInference::InferSliceShape(operand, {0}, {5}, {1});
  ASSERT_FALSE(statusor.ok());

  const auto& failure = statusor.status();
  LOG(INFO) << failure;

  EXPECT_THAT(failure.error_message(),
              HasSubstr("less than or equal to dimension size"))
      << failure;
  EXPECT_THAT(failure.error_message(), HasSubstr("argument shape"))
      << failure;
}
2406 
TEST_F(ShapeInferenceTest, BadSort) {
  // A variadic sort whose keys (length 4) and values (length 5) differ in
  // dimensions is rejected.
  const Shape keys = ShapeUtil::MakeShape(F32, {4});
  const Shape values = ShapeUtil::MakeShape(F32, {5});

  StatusOr<Shape> statusor =
      ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values});
  EXPECT_FALSE(statusor.ok());

  const auto& failure = statusor.status();
  EXPECT_THAT(failure.error_message(), HasSubstr("dimensions must match"))
      << failure;
}
2417 
TEST_F(ShapeInferenceTest, BadSortValuesMismatch) {
  // Even if the first value array matches the keys, any later operand with
  // different dimensions (here the length-5 one) makes the sort invalid.
  const Shape keys = ShapeUtil::MakeShape(F32, {4});
  const Shape matching_values = ShapeUtil::MakeShape(F32, {4});
  const Shape mismatched_values = ShapeUtil::MakeShape(F32, {5});

  StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &matching_values, &mismatched_values});
  EXPECT_FALSE(statusor.ok());

  const auto& failure = statusor.status();
  EXPECT_THAT(failure.error_message(), HasSubstr("dimensions must match"))
      << failure;
}
2429 
TEST_F(ShapeInferenceTest, SortManyValues) {
  // F32 keys sorted together with S32 and U32 value arrays of the same
  // dimensions infer a tuple of all three operand shapes.
  auto keys = ShapeUtil::MakeShape(F32, {4});
  auto values_s32 = ShapeUtil::MakeShape(S32, {4});
  auto values_u32 = ShapeUtil::MakeShape(U32, {4});
  StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &values_s32, &values_u32});
  // Fatal check: ValueOrDie() below would terminate the process on an error
  // status instead of failing this test.
  ASSERT_IS_OK(statusor.status());
  Shape inferred_shape = statusor.ValueOrDie();
  EXPECT_TRUE(ShapeUtil::Compatible(
      inferred_shape,
      ShapeUtil::MakeTupleShape({keys, values_s32, values_u32})));
}
2442 
2443 class ScatterGatherShapeInferenceTest : public ShapeInferenceTest {
2444  protected:
2445   const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});
2446   const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5});
2447   const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32});
2448   const Shape s64_4d_tensor_10_9_8_7_1_ =
2449       ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1});
2450   const Shape s64_4d_tensor_10_9_8_7_5_ =
2451       ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
2452   const Shape s64_4d_tensor_5_10_9_7_6_ =
2453       ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6});
2454   const Shape s64_4d_tensor_10_9_5_7_6_ =
2455       ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
2456   const Shape f32_5d_tensor_50_49_48_47_46_ =
2457       ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
2458   const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
2459       {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_});
2460   const ProgramShape to_apply_ =
2461       ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
2462 };
2463 
2464 // Shape inference tests for Gather.
2465 
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) {
  // Gathers whole columns of a [64, 48] matrix with 32 indices:
  //   - each start index addresses operand dimension 1;
  //   - the slice is {64, 1}, and its size-1 dimension 1 is collapsed;
  //   - the remaining 64-element slice becomes output dimension 0.
  // The expected result is [64, 32].
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{0},
      /*collapsed_slice_dims=*/{1},
      /*start_index_map=*/{1},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
                                       dim_numbers,
                                       /*slice_sizes=*/{64, 1}));

  const Shape expected = ShapeUtil::MakeShape(F32, {64, 32});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
2480 
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) {
  // Gathers whole rows of a [64, 48] matrix with 32 indices:
  //   - each start index addresses operand dimension 0;
  //   - the slice is {1, 48}, and its size-1 dimension 0 is collapsed;
  //   - the remaining 48-element slice becomes output dimension 1.
  // The expected result is [32, 48].
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{1},
      /*collapsed_slice_dims=*/{0},
      /*start_index_map=*/{0},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
                                       dim_numbers,
                                       /*slice_sizes=*/{1, 48}));

  const Shape expected = ShapeUtil::MakeShape(F32, {32, 48});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
2495 
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) {
  // Indices of shape [10, 9, 8, 7, 1] hold their 1-element index vector in
  // dimension 4, and that index selects operand rows (dimension 0). Each
  // {1, 48} row slice loses its collapsed dimension 0 and lands in output
  // dimension 4. The expected result is [10, 9, 8, 7, 48].
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4},
      /*collapsed_slice_dims=*/{0},
      /*start_index_map=*/{0},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(matrix_64_48_,
                                       s64_4d_tensor_10_9_8_7_1_, dim_numbers,
                                       /*slice_sizes=*/{1, 48}));

  const Shape expected = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
2510 
// Batched dynamic slice. Each 5-element index vector addresses every operand
// dim, and nothing is collapsed. The output is the batch dims [10,9,8,7]
// followed by the full slice [30,29,28,27,26].
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(f32_5d_tensor_50_49_48_47_46_,
                                       s64_4d_tensor_10_9_8_7_5_, dnums,
                                       /*slice_sizes=*/{30, 29, 28, 27, 26}));
  const Shape expected =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
  EXPECT_TRUE(ShapeUtil::Equal(inferred, expected))
      << ShapeUtil::HumanString(inferred);
}
2527 
// Operand dim 1 is dynamic and is sliced in full (slice size 2), so the
// corresponding output dim (0) is expected to be dynamic as well.
TEST_F(ScatterGatherShapeInferenceTest, DynamicGatherEntireDimension) {
  const Shape operand =
      ShapeUtil::MakeShape(F32, {3, 2, 1}, {false, true, false});
  const Shape start_indices = ShapeUtil::MakeShape(S64, {});
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0, 1},
          /*collapsed_slice_dims=*/{0},
          /*start_index_map=*/{0},
          /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(operand, start_indices, dnums,
                                       /*slice_sizes=*/{1, 2, 1}));
  const Shape expected = ShapeUtil::MakeShape(F32, {2, 1}, {true, false});
  EXPECT_TRUE(ShapeUtil::Equal(inferred, expected))
      << ShapeUtil::HumanString(inferred);
}
2544 
// Operand dim 0 is dynamic but is collapsed away, so no output dim is
// expected to be dynamic.
TEST_F(ScatterGatherShapeInferenceTest, DynamicGatherCollapsedDimension) {
  const Shape operand =
      ShapeUtil::MakeShape(F32, {3, 2, 1}, {true, false, false});
  const Shape start_indices = ShapeUtil::MakeShape(S64, {});
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0, 1},
          /*collapsed_slice_dims=*/{0},
          /*start_index_map=*/{0},
          /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(operand, start_indices, dnums,
                                       /*slice_sizes=*/{1, 2, 1}));
  const Shape expected = ShapeUtil::MakeShape(F32, {2, 1}, {false, false});
  EXPECT_TRUE(ShapeUtil::Equal(inferred, expected))
      << ShapeUtil::HumanString(inferred);
}
2561 
// A dynamic batch dim in the start indices (dim 1) is expected to propagate
// to the matching batch dim of the output.
TEST_F(ScatterGatherShapeInferenceTest, DynamicIndices) {
  const Shape operand = ShapeUtil::MakeShape(F32, {3, 2, 2});
  const Shape start_indices =
      ShapeUtil::MakeShape(S64, {3, 4, 2}, {false, true, false});
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{2, 3},
          /*collapsed_slice_dims=*/{0},
          /*start_index_map=*/{0, 1},
          /*index_vector_dim=*/2);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(operand, start_indices, dnums,
                                       /*slice_sizes=*/{1, 2, 2}));
  const Shape expected =
      ShapeUtil::MakeShape(F32, {3, 4, 2, 2}, {false, true, false, false});
  EXPECT_TRUE(ShapeUtil::Equal(inferred, expected))
      << ShapeUtil::HumanString(inferred);
}
2579 
// The index vector lives in a middle dim (2) of the start indices. The
// remaining index dims [10,9,7,6] form the batch dims of the output.
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/2);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(f32_5d_tensor_50_49_48_47_46_,
                                       s64_4d_tensor_10_9_5_7_6_, dnums,
                                       /*slice_sizes=*/{30, 29, 28, 27, 26}));

  const Shape expected =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
  EXPECT_TRUE(ShapeUtil::Equal(inferred, expected))
      << ShapeUtil::HumanString(inferred);
}
2597 
// The index vector lives in the leading dim (0) of the start indices. The
// remaining index dims [10,9,7,6] form the batch dims of the output.
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(f32_5d_tensor_50_49_48_47_46_,
                                       s64_4d_tensor_5_10_9_7_6_, dnums,
                                       /*slice_sizes=*/{30, 29, 28, 27, 26}));

  const Shape expected =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
  EXPECT_TRUE(ShapeUtil::Equal(inferred, expected))
      << ShapeUtil::HumanString(inferred);
}
2615 
// A single 5-element index vector with no batch dims is equivalent to a
// dynamic slice. The output is exactly the slice shape.
TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0, 1, 2, 3, 4},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(f32_5d_tensor_50_49_48_47_46_,
                                       s64_vector_5_, dnums,
                                       /*slice_sizes=*/{30, 29, 28, 27, 26}));

  const Shape expected = ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26});
  EXPECT_TRUE(ShapeUtil::Equal(inferred, expected))
      << ShapeUtil::HumanString(inferred);
}
2632 
TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) {
  // The gather indices "tensor" is a scalar S here. It indexes only operand
  // dim 0 (start_index_map={0}), so the slice starts at [S,0,0,0,0]. With
  // slice_sizes {1,30,29,28,27} it covers [S,0,0,0,0]..[S,29,28,27,26]
  // inclusive. Collapsing dim 0 leaves a [30,29,28,27] shaped result.
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{0, 1, 2, 3},
                                  /*collapsed_slice_dims=*/{0},
                                  /*start_index_map=*/{0},
                                  /*index_vector_dim=*/0),
                              /*slice_sizes=*/{1, 30, 29, 28, 27}));

  EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
                               ShapeUtil::MakeShape(F32, {30, 29, 28, 27})))
      << ShapeUtil::HumanString(gather_shape);
}
2650 
// A tuple-shaped operand must be rejected. Gather only accepts arrays.
TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/1);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      tuple_shape_, s64_vector_32_, dnums, /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Expected array argument for input"))
      << result.status();
}
2665 
// A tuple-shaped start-indices argument must be rejected.
TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/0);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      s64_vector_32_, tuple_shape_, dnums, /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Expected array argument for gather indices"))
      << result.status();
}
2680 
// Floating-point start indices (vector_32_ is F32) must be rejected.
TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/0);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      s64_vector_32_, vector_32_, dnums, /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Gather indices parameter must be an integral tensor"))
      << result.status();
}
2695 
// offset_dims {4,5,6,8,7} is not sorted ascending.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingWindowIndices) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 8, 7},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dnums,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("Output window dimensions in gather op must be ascending"))
      << result.status();
}
2712 
// offset_dims {4,5,6,7,7} repeats dim 7.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowIndices) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 7},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dnums,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("Output window dimensions in gather op must not repeat"))
      << result.status();
}
2729 
// offset_dims contains far out-of-range values (99, 100, 101). The error
// names the first offending position in offset_dims, which is 2.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 99, 100, 101},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dnums,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Offset dimension 2 in gather op is out of bounds"))
      << result.status();
}
2745 
// The last offset_dims entry (9) is just past the valid range. The error
// names position 4 in offset_dims.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 9},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dnums,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Offset dimension 4 in gather op is out of bounds"))
      << result.status();
}
2761 
// Five offset dims plus one collapsed dim (4) do not line up with the
// operand's five dims.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{4},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dnums,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("All components of the offset index in a gather op must either "
                "be a offset dimension or explicitly collapsed"))
      << result.status();
}
2779 
// collapsed_slice_dims names operand dim 19, outside [0, 5).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{0, 1, 2, 3, 19},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dnums,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Invalid collapsed_slice_dims set in gather op; valid "
                        "range is [0, 5), got: 19"))
      << result.status();
}
2796 
// collapsed_slice_dims lists dim 3 twice.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{0, 1, 2, 3, 3},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dnums,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Repeated dimensions not allowed in "
                        "collapsed_slice_dims in gather op"))
      << result.status();
}
2813 
// start_index_map has 4 entries, but the index-vector dim (4) of the start
// indices has bound 5.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3},
          /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dnums,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Gather op has 4 elements in start_index_map and "
                        "the bound of dimension index_vector_dim=4 of "
                        "start_indices is 5. These two numbers must be equal."))
      << result.status();
}
2831 
// start_index_map maps index-vector position 4 to operand dim 7, which is
// outside [0, 5).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 7},
          /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dnums,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7"))
      << result.status();
}
2847 
// start_index_map lists operand dim 3 twice.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 3},
          /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dnums,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("Repeated dimensions are not allowed in start_index_map"))
      << result.status();
}
2864 
// collapsed_slice_dims {2,1} is not sorted.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{2, 1},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dnums,
      /*slice_sizes=*/{1, 1, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("collapsed_slice_dims in gather op must be sorted"))
      << result.status();
}
2880 
// Slice size 300 at index 3 is larger than operand dim 3 allows.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowBoundsTooLarge) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7},
          /*collapsed_slice_dims=*/{2},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dnums,
      /*slice_sizes=*/{30, 29, 1, 300, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Slice size at index 3 in gather op is out of range, "
                        "must be within [0, 48), got 300."))
      << result.status();
}
2897 
// Only four slice sizes are given for a five-dim operand.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dnums,
      /*slice_sizes=*/{30, 29, 28, 26});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("Gather op must have one slice size for every input dimension"))
      << result.status();
}
2914 
// Collapsed dim 1 has slice size 29. Only sizes 1 or 0 may be collapsed.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dnums,
      /*slice_sizes=*/{30, 29, 28, 26, 20});
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("Gather op can only collapse slice dims with bound 1 or 0, "
                "but bound is 29 for index 1 at position 0."))
      << result.status();
}
2932 
// index_vector_dim=32 is far beyond the rank of the start indices.
TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
  const GatherDimensionNumbers dnums =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/32);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, dnums,
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Gather index leaf dimension must be within [0, "
                        "rank(start_indices) + 1)"))
      << result.status();
}
2949 
// Shape inference tests for Scatter.
2951 
// Scatter 32 full [64]-tall columns (updates [64,32]) into dim 1 of a [64,48]
// operand. The result shape must equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdates) {
  const ScatterDimensionNumbers dnums =
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1);
  const Shape updates = ShapeUtil::MakeShape(F32, {64, 32});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_, s64_vector_32_,
                                        updates, to_apply_, dnums));
  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
2965 
// Scatter 32 full 48-wide rows (updates [32,48]) into dim 0 of a [64,48]
// operand. The result shape must equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdatesV2) {
  const ScatterDimensionNumbers dnums =
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{1},
          /*inserted_window_dims=*/{0},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/1);
  const Shape updates = ShapeUtil::MakeShape(F32, {32, 48});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_, s64_vector_32_,
                                        updates, to_apply_, dnums));
  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
2979 
// Like TfScatterWithFullUpdates, but each update window covers only 10 of the
// 64 rows. The result shape is still the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdates) {
  const ScatterDimensionNumbers dnums =
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1);
  const Shape updates = ShapeUtil::MakeShape(F32, {10, 32});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_, s64_vector_32_,
                                        updates, to_apply_, dnums));
  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
2993 
// Like TfScatterWithFullUpdatesV2, but each update window covers only 8 of
// the 48 columns. The result shape is still the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdatesV2) {
  const ScatterDimensionNumbers dnums =
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{1},
          /*inserted_window_dims=*/{0},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/1);
  const Shape updates = ShapeUtil::MakeShape(F32, {32, 8});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_, s64_vector_32_,
                                        updates, to_apply_, dnums));
  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
3007 
// The update window spans 65 rows, more than the operand's 64.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) {
  const ScatterDimensionNumbers dnums =
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1);
  const Shape updates = ShapeUtil::MakeShape(F32, {65, 32});
  StatusOr<Shape> result = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates, to_apply_, dnums);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << result.status();
}
3024 
// The update window spans 49 columns, more than the operand's 48.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) {
  const ScatterDimensionNumbers dnums =
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{1},
          /*inserted_window_dims=*/{0},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1);
  const Shape updates = ShapeUtil::MakeShape(F32, {32, 49});
  StatusOr<Shape> result = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates, to_apply_, dnums);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << result.status();
}
3041 
// The updates' scatter dim has 31 entries, but there are 32 scatter indices.
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterWithUpdatesNotMatchingIndices) {
  const ScatterDimensionNumbers dnums =
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1);
  const Shape updates = ShapeUtil::MakeShape(F32, {64, 31});
  StatusOr<Shape> result = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates, to_apply_, dnums);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << result.status();
}
3060 
// The updates' scatter dim (dim 0) has 31 entries, but there are 32 scatter
// indices.
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterWithUpdatesNotMatchingIndicesV2) {
  const ScatterDimensionNumbers dnums =
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{1},
          /*inserted_window_dims=*/{0},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1);
  const Shape updates = ShapeUtil::MakeShape(F32, {31, 48});
  StatusOr<Shape> result = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates, to_apply_, dnums);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << result.status();
}
3079 
// scatter_nd-style: [10,9,8,7] index vectors of length 1 each address a
// row (dim 0) and supply a full 48-wide row. The result shape must equal the
// operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdates) {
  const ScatterDimensionNumbers dnums =
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4},
          /*inserted_window_dims=*/{0},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/4);
  const Shape updates = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_,
                                        s64_4d_tensor_10_9_8_7_1_, updates,
                                        to_apply_, dnums));
  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
3094 
// scatter_nd-style variant: dim 1 is the inserted window dim and each update
// window spans all 64 entries of operand dim 0. The result shape must equal
// the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdatesV2) {
  const ScatterDimensionNumbers dnums =
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/4);
  const Shape updates = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 64});
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_,
                                        s64_4d_tensor_10_9_8_7_1_, updates,
                                        to_apply_, dnums));
  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
3109 
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdates) {
  // The update window along operand dimension 1 is only 10 wide, which is
  // smaller than that dimension's bound (48); partial windows are allowed.
  const Shape updates = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 10});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_,
                                        s64_4d_tensor_10_9_8_7_1_, updates,
                                        to_apply_, dim_numbers));
  // The scatter result keeps the operand's shape.
  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
3124 
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) {
  // With operand dimension 1 inserted, the window covers operand dimension 0
  // and is only 12 wide, which is smaller than that dimension's bound (64).
  const Shape updates = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 12});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(matrix_64_48_,
                                        s64_4d_tensor_10_9_8_7_1_, updates,
                                        to_apply_, dim_numbers));
  // The scatter result keeps the operand's shape.
  EXPECT_TRUE(ShapeUtil::Equal(inferred, matrix_64_48_))
      << ShapeUtil::HumanString(inferred);
}
3139 
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) {
  // The update window along operand dimension 0 is 65 wide, one more than
  // the operand's 64 rows, so inference must reject it.
  const Shape oversized_updates = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 65});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);

  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, oversized_updates, to_apply_,
      dim_numbers);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << result.status();
}
3156 
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterNdWithUpdatesNotMatchingIndices) {
  // The leading (scatter) dimensions of updates start with 9, which does not
  // line up with the corresponding bounds of the scatter indices fixture.
  const Shape mismatched_updates = ShapeUtil::MakeShape(F32, {9, 9, 8, 7, 64});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);

  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, mismatched_updates, to_apply_,
      dim_numbers);
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << result.status();
}
3175 
TEST_F(ScatterGatherShapeInferenceTest, TfBatchDynamicUpdateSlice) {
  // A batch of dynamic-update-slices: each index vector addresses all five
  // operand dimensions, nothing is inserted, and the trailing five update
  // dimensions form the window.
  const Shape updates =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 8},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(f32_5d_tensor_50_49_48_47_46_,
                                        s64_4d_tensor_10_9_8_7_5_, updates,
                                        to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred);
}
3191 
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDim) {
  // The index vectors live in dimension 2 of the scatter indices rather than
  // in the last dimension.
  const Shape updates =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 8},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/2);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(f32_5d_tensor_50_49_48_47_46_,
                                        s64_4d_tensor_10_9_5_7_6_, updates,
                                        to_apply_, dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred);
}
3208 
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) {
  // The index vectors live in the leading dimension (0) of the scatter
  // indices.
  const Shape updates =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 8},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/0);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(f32_5d_tensor_50_49_48_47_46_,
                                        s64_4d_tensor_5_10_9_7_6_, updates,
                                        to_apply_, dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred);
}
3225 
TEST_F(ScatterGatherShapeInferenceTest, NoUpdateScatterDims) {
  // This is equivalent to a dynamic update slice: a single index vector and
  // an updates tensor made up entirely of window dimensions.
  const Shape window_only_updates =
      ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0, 1, 2, 3, 4},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/0);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(f32_5d_tensor_50_49_48_47_46_,
                                        s64_vector_5_, window_only_updates,
                                        to_apply_, dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred);
}
3242 
TEST_F(ScatterGatherShapeInferenceTest, ScalarScatterIndices) {
  // The scalar indices "tensor" is a scalar S here that's used to update a
  // [30,29,28,27] shaped tensor within the operand at position S.
  const Shape updates = ShapeUtil::MakeShape(F32, {30, 29, 28, 27});
  // Operand dimension 0 is inserted, so the four update dimensions cover
  // operand dimensions 1..4.
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0, 1, 2, 3},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/0);

  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferScatterShape(f32_5d_tensor_50_49_48_47_46_,
                                        s64_scalar_, updates, to_apply_,
                                        dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred);
}
3260 
TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedTensorInput) {
  // A tuple-shaped operand must be rejected before any dimension checks.
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      tuple_shape_, s64_vector_32_, s64_vector_32_, to_apply_, dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Expected array argument for operand"))
      << result.status();
}
3274 
TEST_F(ScatterGatherShapeInferenceTest,
       ScatterWithTupleShapedScatterIndicesInput) {
  // Tuple-shaped scatter indices must be rejected.
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/0);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      s64_vector_32_, tuple_shape_, s64_vector_32_, to_apply_, dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Expected array argument for scatter indices"))
      << result.status();
}
3289 
TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) {
  // A tuple-shaped updates argument must be rejected.
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/0);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      s64_vector_32_, s64_vector_32_, tuple_shape_, to_apply_, dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Expected array argument for updates"))
      << result.status();
}
3303 
TEST_F(ScatterGatherShapeInferenceTest, FloatingPointScatterIndicesInput) {
  // vector_32_ is an F32 vector; scatter indices must be integral.
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/0);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      s64_vector_32_, vector_32_, s64_vector_32_, to_apply_, dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Scatter indices parameter must be an integral tensor"))
      << result.status();
}
3317 
TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) {
  // index_vector_dim = 10 is far outside [0, rank(scatter_indices) + 1).
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/10);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Scatter index leaf dimension must be within [0, "
                        "rank(scatter_indices) + 1)"))
      << result.status();
}
3333 
TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdates) {
  // Four scatter dimensions plus three window dimensions require rank-7
  // updates; this updates tensor has an extra trailing dimension (rank 8).
  const Shape rank8_updates =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 50});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, rank8_updates,
      to_apply_, dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Updates tensor must be of rank 7; got 8."))
      << result.status();
}
3348 
TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdateComputation) {
  // The update computation takes a single f32 parameter, but a binary
  // combiner is required.
  const ProgramShape unary_combiner = ShapeUtil::MakeProgramShape({f32_}, f32_);
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), unary_combiner,
      dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("Reduction function must take 2 parameters, but takes 1"))
      << result.status();
}
3367 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) {
  // update_window_dims ends with 8, 7, so it is not sorted.
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 8, 7},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
      dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("update_window_dims in scatter op must be sorted"))
      << result.status();
}
3383 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedUpdateWindowDims) {
  // Dimension 7 appears twice in update_window_dims.
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 7},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
      dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("update_window_dims in scatter op must not repeat"))
      << result.status();
}
3399 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) {
  // The updates tensor has rank 9, so window dimension 9 is out of range.
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 9},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
      dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Invalid update_window_dims set in scatter op; valid "
                        "range is [0, 9)"))
      << result.status();
}
3416 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) {
  // inserted_window_dims is given in descending order.
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{2, 1},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("inserted_window_dims in scatter op must be sorted"))
      << result.status();
}
3432 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedInsertedWindowDims) {
  // Operand dimension 1 is listed twice in inserted_window_dims.
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 1},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("inserted_window_dims in scatter op must not repeat"))
      << result.status();
}
3448 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) {
  // The operand has rank 5, so inserted window dimension 5 is out of range.
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 5},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Invalid inserted_window_dims set in scatter op; valid "
                        "range is [0, 5)"))
      << result.status();
}
3465 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) {
  // Only four operand dimensions are mapped, but dimension 4 of the indices
  // (the index vector dimension) has bound 5.
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and "
                "the bound of dimension index_vector_dim=4 of scatter_indices "
                "is 5. These two numbers must be equal"))
      << result.status();
}
3484 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) {
  // The last index component is mapped to operand dimension 10, which does
  // not exist in the rank-5 operand.
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain "
                        "is [0, 5), got: 4->10"))
      << result.status();
}
3501 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) {
  // Operand dimension 2 is the target of two different index components.
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr(
          "Repeated dimensions not allowed in scatter_dims_to_operand_dims"))
      << result.status();
}
3518 }
3519 
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_InsufficientWindowDims) {
  // Four window dimensions and no inserted dimensions cannot account for all
  // five operand dimensions.
  const Shape rank4_updates = ShapeUtil::MakeShape(F32, {30, 29, 28, 27});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0, 1, 2, 3},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/0);
  const StatusOr<Shape> result = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_scalar_, rank4_updates, to_apply_,
      dim_numbers);

  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr(
          "Scatter op has window of size 4; doesn't match operand of rank 5."))
      << result.status();
}
3537 
3538 }  // namespace
3539 }  // namespace xla
3540