/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/shape_inference.h"

// The two standard-library header names below were lost in extraction;
// <optional> and <string> cover the std::nullopt and std::string uses in this
// file (restored by inference from usage).
#include <optional>
#include <string>

#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace {

using ::testing::ContainsRegex;
using ::testing::HasSubstr;

class ShapeInferenceTest : public ::testing::Test {
 protected:
  // Some handy scalar shapes.
  const Shape s32_ = ShapeUtil::MakeShape(S32, {});
  const Shape f16_ = ShapeUtil::MakeShape(F16, {});
  const Shape f32_ = ShapeUtil::MakeShape(F32, {});
  const Shape f64_ = ShapeUtil::MakeShape(F64, {});
  const Shape pred_ = ShapeUtil::MakeShape(PRED, {});

  // Some handy vector and matrix shapes of F32 type.
  // Suffix: vector_length_, matrix_rows_cols_
  const Shape vector_32_ = ShapeUtil::MakeShape(F32, {32});
  const Shape vector_64_ = ShapeUtil::MakeShape(F32, {64});
  const Shape matrix_32_48_ = ShapeUtil::MakeShape(F32, {32, 48});
  const Shape matrix_32_64_ = ShapeUtil::MakeShape(F32, {32, 64});
  const Shape matrix_64_48_ = ShapeUtil::MakeShape(F32, {64, 48});

  // Some handy S32 arrays.
  const Shape s32matrix_64_64_ = ShapeUtil::MakeShape(S32, {64, 64});
};

// Subclass for testing InferReduceShape.
class ReduceShapeInferenceTest : public ShapeInferenceTest {
 protected:
  // Helper that runs reduce shape inference with the input 'arg' and given
  // dimensions to reduce, and checks the inferred shape is as expected. The
  // element type here is hard-coded to F32.
  void ExpectInferredReduceShape(
      const Shape& expected_inferred_shape, const Shape& arg,
      absl::Span<const int64_t> dimensions_to_reduce) {
    ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
    auto inferred_status = ShapeInference::InferReduceShape(
        {&arg, &f32_}, dimensions_to_reduce, to_apply);
    EXPECT_IS_OK(inferred_status.status());
    EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
                                 inferred_status.ValueOrDie()));
  }
};

// Subclass for testing InferSelectAndScatterShape.
class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest {
 protected:
  SelectAndScatterShapeInferenceTest() {
    operand_shape_ = ShapeUtil::MakeShape(F32, {8, 16});
    source_shape_ = ShapeUtil::MakeShape(F32, {4, 8});
    WindowDimension dim;
    dim.set_size(2);
    dim.set_stride(2);
    dim.set_padding_low(0);
    dim.set_padding_high(0);
    dim.set_window_dilation(1);
    dim.set_base_dilation(1);
    *window_.add_dimensions() = dim;
    *window_.add_dimensions() = dim;
    init_value_shape_ = ShapeUtil::MakeShape(F32, {});
    select_program_shape_ = ShapeUtil::MakeProgramShape(
        {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
    scatter_program_shape_ = ShapeUtil::MakeProgramShape(
        {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  }

  Shape operand_shape_;
  Shape source_shape_;
  Window window_;
  Shape init_value_shape_;
  ProgramShape select_program_shape_;
  ProgramShape scatter_program_shape_;
};

TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape);
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) {
  Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_});
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, tuple, tuple);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Expected array argument for select"));
}

TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Operands to select and predicate must be the same shape"));
}

TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) {
  auto predarray = ShapeUtil::MakeShape(PRED, {64, 48});
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, SelectBadShapes) {
  auto inferred_status_error1 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_);
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("Operands to select must be the same shape"));

  auto inferred_status_error2 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("pred operand must have PRED"));

  auto inferred_status_error3 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_,
      matrix_64_48_);
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(
      inferred_status_error3.status().error_message(),
      HasSubstr("Operands to select and predicate must be the same shape"));

  // Tuples have a TUPLE element type and cannot be the pred of a select.
  auto inferred_status_error4 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}),
      ShapeUtil::MakeTupleShape({f32_, f32_}),
      ShapeUtil::MakeTupleShape({f32_, f32_}));
  ASSERT_FALSE(inferred_status_error4.ok());
  ASSERT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("Expected array argument for select pred"));
}

TEST_F(ShapeInferenceTest, ClampAllMatrix) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, ClampAllScalar) {
  auto inferred_status =
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_);
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, ClampMinScalar) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}

TEST_F(ShapeInferenceTest, ClampMaxScalar) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}

TEST_F(ShapeInferenceTest, ClampOperandScalar) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}

TEST_F(ShapeInferenceTest, ClampMinMatrix) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, f32_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}

TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, f32_, matrix_64_48_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}

TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, f32_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}
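
// Note on the failures above: this version of ternary shape inference
// requires clamp's min, operand, and max to have identical shapes; a scalar
// min or max is not implicitly broadcast, hence the "Clamp with different
// shapes" errors.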

TEST_F(ShapeInferenceTest, ClampBadShapes) {
  // Type mismatch
  ASSERT_FALSE(
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, s32_, f32_, f32_)
          .ok());
  ASSERT_FALSE(
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, s32_, f32_)
          .ok());
  ASSERT_FALSE(
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, s32_)
          .ok());
  // Dimension mismatch
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
                   HloOpcode::kClamp, vector_64_, vector_32_, vector_32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
                   HloOpcode::kClamp, vector_32_, vector_64_, vector_32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
                   HloOpcode::kClamp, vector_32_, vector_32_, vector_64_)
                   .ok());
  // Dimension mismatch, where one operand is a scalar
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
                   HloOpcode::kClamp, vector_64_, vector_32_, f32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
                   HloOpcode::kClamp, vector_64_, f32_, vector_32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
                   HloOpcode::kClamp, f32_, vector_64_, vector_32_)
                   .ok());
}

TEST_F(ShapeInferenceTest, Complex) {
  auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
                           absl::Span<const int64_t> bcast) {
    return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
                                              bcast);
  };
  // Inputs must be FP.
  ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
  ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
  // Component types must match.
  ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
  // Only F32->C64 and F64->C128 supported.
  ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok());
  // Validate correct uses.
  Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
  TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {})));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));

  Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64});
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(vector_64_, matrix_32_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, vector_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, matrix_32_64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));

  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {})));
}

TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
  StatusOr<Shape> result =
      ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
                               ShapeUtil::MakeTupleShape({s32_, f32_})));
}

TEST_F(ShapeInferenceTest, ReduceWindowInHalf) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8});
  Window window;
  WindowDimension dim;
  dim.set_size(2);
  dim.set_stride(2);
  dim.set_padding_low(0);
  dim.set_padding_high(0);
  dim.set_window_dilation(1);
  dim.set_base_dilation(1);
  *window.add_dimensions() = dim;
  *window.add_dimensions() = dim;
  Shape window_shape = ShapeUtil::MakeShape(F32, {2, 2});
  Shape init_value_shape = ShapeUtil::MakeShape(F32, {});
  Shape float_scalar = ShapeUtil::MakeShape(F32, {});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      matrix_shape, init_value_shape, window, to_apply);

  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), inferred));
}

TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) {
  auto inferred_status_ok = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_IS_OK(inferred_status_ok.status());
  Shape inferred = inferred_status_ok.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, inferred));
}

TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) {
  Shape source_shape_fail = ShapeUtil::MakeShape(F32, {4, 6});
  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, source_shape_fail,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred_status_fail.ok());
  ASSERT_THAT(inferred_status_fail.status().error_message(),
              HasSubstr("Source shape does not match"));
}

TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
  ProgramShape select_program_shape_fail =
      ShapeUtil::MakeProgramShape({ShapeUtil::MakeShape(F32, {})}, pred_);
  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_fail, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred_status_fail.ok());
  ASSERT_THAT(inferred_status_fail.status().error_message(),
              HasSubstr("Select function must take 2 parameters"));
}

TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
  ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_fail, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred_status_fail.ok());
  ASSERT_THAT(inferred_status_fail.status().error_message(),
              HasSubstr("Select function must have rank-0 PRED"));
}

TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
  ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_fail, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred_status_fail.ok());
  ASSERT_THAT(inferred_status_fail.status().error_message(),
              HasSubstr("Select function's first parameter"));
}

TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
  ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(U32, {})}, pred_);
  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_fail, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred_status_fail.ok());
  ASSERT_THAT(inferred_status_fail.status().error_message(),
              HasSubstr("Select function's second parameter"));
}
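
// All-gather concatenates shards along all_gather_dimension, so the output
// size in that dimension is operand_size * shard_count (1 * 8 = 8 below),
// and the "start" op returns a (operand, output) tuple.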
TEST_F(ShapeInferenceTest, AllGatherStart) {
  const Shape operand = ShapeUtil::MakeShape(F32, {1, 8, 4});
  const Shape expected_shape = ShapeUtil::MakeTupleShape(
      {operand, ShapeUtil::MakeShape(F32, {8, 8, 4})});

  auto inferred_ag_shape = ShapeInference::InferAllGatherStartShape(
      {&operand}, /*all_gather_dimension=*/0, /*shard_count=*/8);
  EXPECT_TRUE(inferred_ag_shape.ok());
  EXPECT_TRUE(
      ShapeUtil::Equal(inferred_ag_shape.ValueOrDie(), expected_shape));
}

TEST_F(ShapeInferenceTest, AllGatherDone) {
  const Shape input_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1, 8, 4}),
                                 ShapeUtil::MakeShape(F32, {8, 8, 4})});
  const Shape expected_shape = ShapeUtil::MakeShape(F32, {8, 8, 4});

  auto inferred_ag_done_shape =
      ShapeInference::InferAllGatherDoneShape(input_shape);
  EXPECT_TRUE(inferred_ag_done_shape.ok());
  EXPECT_TRUE(
      ShapeUtil::Equal(inferred_ag_done_shape.ValueOrDie(), expected_shape));
}

TEST_F(ShapeInferenceTest, Convolve) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(3);
  dim0->set_stride(2);
  dim0->set_padding_low(1);
  dim0->set_padding_high(1);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(1);
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(0);
  dim1->set_padding_high(0);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/std::nullopt);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  auto dim0 = window.add_dimensions();
  dim0->set_size(3);
  dim0->set_stride(3);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim0->set_window_dilation(6);
  dim0->set_base_dilation(1);

  auto dim1 = window.add_dimensions();
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim1->set_window_dilation(2);
  dim1->set_base_dilation(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/std::nullopt);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
                               inferred_shape));
}
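
// For reference, the expected spatial sizes above and below follow the usual
// window formula: out = floor((base + pad_lo + pad_hi - effective_window) /
// stride) + 1, where effective_window = (size - 1) * window_dilation + 1 and
// base is the input size stretched by base_dilation. E.g. in
// ConvolveWithWindowDilation, x0: floor((103 - 13) / 3) + 1 = 31 and
// x1: floor((4 + 2 + 1 - 3) / 1) + 1 = 5.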
TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  auto dim0 = window.add_dimensions();
  dim0->set_size(4);
  dim0->set_stride(3);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(6);

  auto dim1 = window.add_dimensions();
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(2);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/std::nullopt);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
  // Dimension order for this test: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2});

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(3);
  dnums.set_output_batch_dimension(3);
  dnums.set_input_feature_dimension(2);
  dnums.set_output_feature_dimension(2);
  dnums.add_input_spatial_dimensions(0);
  dnums.add_output_spatial_dimensions(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.add_output_spatial_dimensions(1);
  dnums.set_kernel_input_feature_dimension(0);  // duplicated with kernel_x0
  dnums.set_kernel_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);

  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(2);
  dim0->set_stride(1);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim1->set_size(3);
  dim1->set_stride(2);
  dim1->set_padding_low(1);
  dim1->set_padding_high(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/std::nullopt);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("each dimension exactly once"));
}
TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) {
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(2);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(3);
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {60, 38, 17, 13});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {38, 10, 4, 4});
  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(4);
  dim1->set_size(4);
  dim0->set_padding_low(0);
  dim0->set_padding_high(2);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim0->set_stride(1);
  dim1->set_stride(1);
  dim0->set_window_dilation(3);
  dim1->set_window_dilation(2);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6,
      window, dnums, /*preferred_element_type=*/std::nullopt);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("to be a multiple of batch group count"));
}

struct ConvolveArgs {
  Shape lhs_shape;
  Shape rhs_shape;
  ConvolutionDimensionNumbers dnums;
  Window window;
};

ConvolveArgs MakeConvolveArgs(PrimitiveType lhs_type, PrimitiveType rhs_type) {
  ConvolveArgs args;
  ConvolutionDimensionNumbers& dnums = args.dnums;

  // Dimension order: batch, feature, x0, x1
  args.lhs_shape = ShapeUtil::MakeShape(lhs_type, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  args.rhs_shape = ShapeUtil::MakeShape(rhs_type, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  auto dim0 = args.window.add_dimensions();
  auto dim1 = args.window.add_dimensions();
  dim0->set_size(3);
  dim0->set_stride(2);
  dim0->set_padding_low(1);
  dim0->set_padding_high(1);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(1);
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(0);
  dim1->set_padding_high(0);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(1);
  return args;
}

TEST_F(ShapeInferenceTest, ConvolveWithBF16_F16) {
  ConvolveArgs args = MakeConvolveArgs(BF16, F16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/std::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest, ConvolveWithF16_BF16) {
  ConvolveArgs args = MakeConvolveArgs(F16, BF16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/std::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest, ConvolveWithS32_U32) {
  ConvolveArgs args = MakeConvolveArgs(S32, U32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/std::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest, ConvolveWithU32_S32) {
  ConvolveArgs args = MakeConvolveArgs(U32, S32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/std::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}
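
// The tests below exercise preferred_element_type, which overrides the
// inferred result type: widening is always allowed, a floating-point result
// may also be narrowed (e.g. F32 -> BF16), but an integral result must not be
// narrower than the inferred type.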
TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S16));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S16, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementTypeSameAsInferredType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S32));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest,
       FloatingPointConvolveWithNarrowerPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(F32, F32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/BF16));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest,
       FloatingPointConvolveWithIntegralPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(BF16, BF16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S32));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest,
       IntegralConvolveWithFloatingPointPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/F32));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest,
       ConvolveWithPreferredElementTypeWithDifferentSignedness) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/U32));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(U32, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest, ConvolveWithNarrowerPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  auto inferred_status =
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S8)
          .status();
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.error_message(),
              HasSubstr("must not be narrower than the original type"));
}
namespace fft {

static const char* unsupported_rank = "only supports ranks 1-3";
static const char* invalid_rank = "requires input of at least same rank";
static const char* requires_complex_input = "requires complex input type";
static const char* requires_f32_input = "requires F32 or F64 input type";
static const char* dimensions_match = "innermost dimensions match fft_length";
static const char* innermost_dimension_matches =
    "innermost dimension matches fft_length/2+1";

static void Pass(const Shape& shape, FftType type,
                 absl::Span<const int64_t> length,
                 const Shape& expected_shape) {
  auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(inferred_shape, expected_shape));
}

static void Fail(const Shape& shape, FftType type,
                 absl::Span<const int64_t> length, absl::string_view message) {
  auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr(std::string(message)));
}

}  // namespace fft

TEST_F(ShapeInferenceTest, InferFftShapeTestFftRanks) {
  FftType type = FftType::FFT;
  Shape shape = ShapeUtil::MakeShape(C64, {16, 8});
  fft::Fail(shape, type, {}, fft::unsupported_rank);
  fft::Pass(shape, type, {8}, shape);
  fft::Pass(shape, type, {16, 8}, shape);
  fft::Fail(shape, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(shape, type, {64, 32, 16, 8}, fft::unsupported_rank);
}

TEST_F(ShapeInferenceTest, InferFftShapeTestFftTypes) {
  FftType type = FftType::FFT;
  Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
  Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
  fft::Pass(shape_c128, type, {16, 8}, shape_c128);
}

TEST_F(ShapeInferenceTest, InferFftShapeTestIfftRanks) {
  FftType type = FftType::IFFT;
  Shape shape = ShapeUtil::MakeShape(C64, {16, 8});
  fft::Fail(shape, type, {}, fft::unsupported_rank);
  fft::Pass(shape, type, {8}, shape);
  fft::Pass(shape, type, {16, 8}, shape);
  fft::Fail(shape, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(shape, type, {64, 32, 16, 8}, fft::unsupported_rank);
}

TEST_F(ShapeInferenceTest, InferFftShapeTestIfftTypes) {
  FftType type = FftType::IFFT;
  Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
  Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
  fft::Pass(shape_c128, type, {16, 8}, shape_c128);
}

TEST_F(ShapeInferenceTest, InferFftShapeTestRfftRanks) {
  FftType type = FftType::RFFT;
  Shape shape_in = ShapeUtil::MakeShape(F32, {16, 8});
  Shape shape_out = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Fail(shape_in, type, {}, fft::unsupported_rank);
  fft::Pass(shape_in, type, {8}, shape_out);
  fft::Pass(shape_in, type, {16, 8}, shape_out);
  fft::Fail(shape_in, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(shape_in, type, {64, 32, 16, 8}, fft::unsupported_rank);
}

TEST_F(ShapeInferenceTest, InferFftShapeTestRfftDimensions) {
  FftType type = FftType::RFFT;
  Shape shape = ShapeUtil::MakeShape(F32, {16, 8});
  fft::Fail(shape, type, {4}, fft::dimensions_match);
  fft::Fail(shape, type, {16, 4}, fft::dimensions_match);
  fft::Fail(shape, type, {8, 8}, fft::dimensions_match);
  fft::Fail(shape, type, {8, 16}, fft::dimensions_match);

  Shape zero_shape_in = ShapeUtil::MakeShape(F32, {16, 0});
  Shape zero_shape_out = ShapeUtil::MakeShape(C64, {16, 0});
  fft::Pass(zero_shape_in, type, {0}, zero_shape_out);
  fft::Pass(zero_shape_in, type, {16, 0}, zero_shape_out);

  Shape even_shape_in = ShapeUtil::MakeShape(F32, {16, 8});
  Shape odd_shape_in = ShapeUtil::MakeShape(F32, {16, 9});
  Shape shape_out = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Pass(even_shape_in, type, {16, 8}, shape_out);
  fft::Pass(odd_shape_in, type, {16, 9}, shape_out);
}
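
// RFFT of a real array with innermost size n yields a complex innermost
// dimension of n/2 + 1 (Hermitian symmetry), which is why fft_length 8 and 9
// both produce size 5 above; IRFFT applies the inverse mapping.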
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftTypes) {
  FftType type = FftType::RFFT;
  Shape shape_c64 = ShapeUtil::MakeShape(C64, {16, 8});
  Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(shape_c64, type, {16, 8}, fft::requires_f32_input);
  fft::Fail(shape_c128, type, {16, 8}, fft::requires_f32_input);
}

TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftRanks) {
  FftType type = FftType::IRFFT;
  Shape shape_in = ShapeUtil::MakeShape(C64, {16, 5});
  Shape shape_out = ShapeUtil::MakeShape(F32, {16, 8});
  fft::Fail(shape_in, type, {}, fft::unsupported_rank);
  fft::Pass(shape_in, type, {8}, shape_out);
  fft::Pass(shape_in, type, {16, 8}, shape_out);
  fft::Fail(shape_in, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(shape_in, type, {64, 32, 16, 8}, fft::unsupported_rank);
}

TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftDimensions) {
  FftType type = FftType::IRFFT;
  Shape shape = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Fail(shape, type, {5}, fft::innermost_dimension_matches);
  fft::Fail(shape, type, {16, 5}, fft::innermost_dimension_matches);
  fft::Fail(shape, type, {8, 8}, fft::dimensions_match);
  fft::Fail(shape, type, {8, 9}, fft::dimensions_match);

  Shape zero_shape_in = ShapeUtil::MakeShape(C64, {16, 0});
  Shape zero_shape_out = ShapeUtil::MakeShape(F32, {16, 0});
  fft::Pass(zero_shape_in, type, {0}, zero_shape_out);
  fft::Pass(zero_shape_in, type, {16, 0}, zero_shape_out);

  Shape even_shape_out = ShapeUtil::MakeShape(F32, {16, 8});
  Shape odd_shape_out = ShapeUtil::MakeShape(F32, {16, 9});
  fft::Pass(shape, type, {16, 8}, even_shape_out);
  fft::Pass(shape, type, {16, 9}, odd_shape_out);
}

TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftTypes) {
  FftType type = FftType::IRFFT;
  Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
  Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 5});
  Shape shape_f64_out = ShapeUtil::MakeShape(F64, {16, 8});
  fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
  fft::Pass(shape_c128, type, {16, 8}, shape_f64_out);
}
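
// Map applies a scalar computation elementwise: the result keeps the common
// operand dimensions and takes its element type from to_apply's result, as
// the next test demonstrates (f32[20] -> s32[20]).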
TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
  Shape arg = ShapeUtil::MakeShape(F32, {20});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_);
  auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
  EXPECT_IS_OK(inferred_status.status());
  Shape expected = ShapeUtil::MakeShape(S32, {20});
  EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, Map) {
  auto inferred_status_r1f32 = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  EXPECT_IS_OK(inferred_status_r1f32.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie()));

  // It's OK to provide a single argument, as long as the applied arity matches
  // (this degenerates to a Map).
  auto inferred_status_r1f32_one = ShapeInference::InferMapShape(
      {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0});
  EXPECT_IS_OK(inferred_status_r1f32_one.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie()));

  auto inferred_status_r2s32 = ShapeInference::InferMapShape(
      {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_},
      ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1});
  EXPECT_IS_OK(inferred_status_r2s32.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie()));

  auto no_args_error = ShapeInference::InferMapShape(
      {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {});
  ASSERT_FALSE(no_args_error.ok());
  ASSERT_THAT(no_args_error.status().error_message(),
              HasSubstr("expects at least one argument"));

  auto args_diff_shapes_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_64_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(args_diff_shapes_error.ok());
  ASSERT_THAT(args_diff_shapes_error.status().error_message(),
              HasSubstr("requires all operands to have the same shape"));

  auto arity_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_),
      {0});
  ASSERT_FALSE(arity_error.ok());
  ASSERT_THAT(arity_error.status().error_message(),
              HasSubstr("function arity must match"));

  auto output_shape_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0});
  ASSERT_FALSE(output_shape_error.ok());
  ASSERT_THAT(output_shape_error.status().error_message(),
              HasSubstr("result has to be a scalar"));

  auto param_shape_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0});
  ASSERT_FALSE(param_shape_error.ok());
  ASSERT_THAT(param_shape_error.status().error_message(),
              HasSubstr("parameter has to be a scalar"));

  auto param_element_type_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0});
  ASSERT_FALSE(param_element_type_error.ok());
  ASSERT_THAT(param_element_type_error.status().error_message(),
              HasSubstr("parameter type has to match argument"));

  Shape arg = ShapeUtil::MakeShape(F32, {20});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
  auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
  EXPECT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie()));

  auto inferred_status_error1 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("arity must match number of arguments"));

  auto inferred_status_error2 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0});
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("has to be a scalar"));

  auto inferred_status_error3 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0});
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("has to be a scalar"));

  auto inferred_status_error5 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0});
  ASSERT_FALSE(inferred_status_error5.ok());
  ASSERT_THAT(inferred_status_error5.status().error_message(),
              HasSubstr("parameter type has to match argument"));
}
TEST_F(ShapeInferenceTest, MapWithDifferentInputTypes) {
  Shape arg0 = ShapeUtil::MakeShape(F32, {20});
  Shape arg1 = ShapeUtil::MakeShape(S32, {20});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, s32_}, s32_);
  auto inferred_status =
      ShapeInference::InferMapShape({&arg0, &arg1}, to_apply, {0});
  EXPECT_IS_OK(inferred_status.status());
  Shape expected = ShapeUtil::MakeShape(S32, {20});
  EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie()));
}

TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) {
  ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {128}),
                            /*dimensions_to_reduce=*/{0});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstDimension) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3, 4}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongMiddleDimension) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2, 4}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{1});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstTwoDimensions) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {4}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0, 1});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongLastTwoDimensions) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{1, 2});
}
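
// In general, InferReduceShape drops the reduced dimensions from the input
// shape; the order in which dimensions_to_reduce are listed is irrelevant.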
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstAndLastDimensions) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0, 2});

  // Check that the order of dimensions_to_reduce doesn't matter.
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{2, 0});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
  ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0, 1, 2});
}

TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}),
                               inferred_status.ValueOrDie()));
}

TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> args = {&f32_arg_shape, &s32_arg_shape};
  std::vector<const Shape*> inits = {&f32_, &s32_};
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  std::vector<int64_t> window_dimensions = {1, 2, 4};
  std::vector<int64_t> window_strides = {1, 1, 1};
  std::vector<std::pair<int64_t, int64_t>> padding_values =
      MakePadding(f32_arg_shape.dimensions(), window_dimensions,
                  window_strides, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(
      Window window,
      ShapeInference::InferWindowFromDimensions(
          window_dimensions, window_strides, padding_values, {}, {}));
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
  VLOG(2) << inferred_status.ValueOrDie().ToString() << "\n";
  EXPECT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {5, 2, 0}),
                                 ShapeUtil::MakeShape(S32, {5, 2, 0})}),
      inferred_status.ValueOrDie()));
}
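
// The {5, 2, 0} result above follows from valid padding with stride 1:
// out = max(in - window + 1, 0) per dimension, i.e. 5-1+1 = 5, 3-2+1 = 2,
// and 1-4+1 < 0, which clamps to 0.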
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_},
                                  ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("must take 4 parameters, but takes 6 parameter(s)"));
}

TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr(
          "parameter shape differs from the result shape: s32[] vs f32[]"));
}

TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto inferred_status =
      ShapeInference::InferReduceShape({}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("must have at least 2 arguments, has 0"));
}

TEST_F(ReduceShapeInferenceTest, ErrorBadReduceWindowInput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> args = {&f32_arg_shape, &s32_arg_shape};
  std::vector<const Shape*> inits = {&f32_, &s32_};
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, f32_, f32_, f32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  std::vector<int64_t> window_dimensions = {1, 2, 4};
  std::vector<int64_t> window_strides = {1, 1, 1};
  std::vector<std::pair<int64_t, int64_t>> padding_values =
      MakePadding(f32_arg_shape.dimensions(), window_dimensions,
                  window_strides, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(
      Window window,
      ShapeInference::InferWindowFromDimensions(
          window_dimensions, window_strides, padding_values, {}, {}));
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
  EXPECT_FALSE(inferred_status.status().ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("f32[] vs s32[]"));
}

TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_);
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr(
          "must produce a tuple with 2 elements, but produces a scalar"));
}

TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_},
      ShapeUtil::MakeTupleShape({f32_, s32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but has 3 elements"));
}

TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("accumulator shape at index 0 differs from the "
                        "init_value shape: s32[] vs f32[]"));
}

TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  auto inferred_status = ShapeInference::InferReduceShape(
      {&arg_shape, &f32_},
      /*dimensions_to_reduce=*/{3, 4}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("out-of-bounds dimension"));
}

TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
  ProgramShape to_apply =
      ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  auto inferred_status =
      ShapeInference::InferReduceShape({&arg_shape, &f32_},
                                       /*dimensions_to_reduce=*/{0}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("take 2 parameters"));
}
TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  auto inferred_status =
      ShapeInference::InferReduceShape({&arg_shape, &f32_},
                                       /*dimensions_to_reduce=*/{0}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("0-th parameter shape differs"));
}

TEST_F(ReduceShapeInferenceTest, ReduceWithRepeatedReduceDimension) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  auto inferred_status = ShapeInference::InferReduceShape(
      {&arg_shape, &f32_},
      /*dimensions_to_reduce=*/{0, 0}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Duplicate reduction dimension: 0"));
}

TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {1, 1});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), inferred));
}

TEST_F(ShapeInferenceTest, InferSliceWithDynamicDimensions) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}, {true, true});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {33, 64}, {1, 1});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(
      ShapeUtil::MakeShape(F32, {1, 64}, {false, true}), inferred));
}

TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {2, 4});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred));
}

TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {15, 0}, {20, 13}, {2, 4});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}), inferred));
}
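
// A sliced dimension has size ceil((limit - start) / stride); for the
// non-integral case above that is ceil((20-15)/2) = 3 and ceil((13-0)/4) = 4.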
TEST_F(ShapeInferenceTest, InferInvalidStride) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {0, 1});
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
            inferred_status.status().code());
}

TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {1, 1});
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
            inferred_status.status().code());
}

TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
  Shape vector_shape = ShapeUtil::MakeShape(F32, {17});
  auto inferred_status =
      ShapeInference::InferSliceShape(vector_shape, {2}, {4}, {1});
  ASSERT_TRUE(inferred_status.ok());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2})));
}

TEST_F(ShapeInferenceTest, InferConstIndexShape) {
  Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto inferred0_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, 0);
  auto inferred1_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, 1);
  ASSERT_IS_OK(inferred0_status.status());
  ASSERT_IS_OK(inferred1_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred0_status.ValueOrDie()));
  ASSERT_TRUE(ShapeUtil::Equal(s32_, inferred1_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) {
  Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto inferredNegative_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, -1);
  auto inferred2_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, 2);
  ASSERT_FALSE(inferredNegative_status.ok());
  ASSERT_FALSE(inferred2_status.ok());
  EXPECT_THAT(inferredNegative_status.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
  EXPECT_THAT(inferred2_status.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
}

TEST_F(ShapeInferenceTest, InferPowShape) {
  auto ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto inferred_status = ShapeInference::InferBinaryOpShape(
      HloOpcode::kPower, ten_floats, f32_, {});
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, InferCompareShape) {
  auto ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto inferred_status = ShapeInference::InferBinaryOpShape(
      HloOpcode::kCompare, ten_floats, f32_, {});
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
                               inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, InferReshapeDegenerateCombine) {
  // [1, <=1]
  //   | reshape
  // [<=1]
  //
  // Both output dimension can be dynamic, use inferred_dimension to tie-break.
  auto operand = ShapeUtil::MakeShape(F32, {1, 1}, {false, true});
  auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {1},
                                                  /*inferred_dimension=*/-1);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {1}, {true}), status.ValueOrDie());
}
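
// For the reshape tests here, the dimensions/new_sizes arguments describe the
// reshape itself, while inferred_dimension (-1 for none) tie-breaks which
// output dimension inherits the dynamic bound when several could, per the
// diagrams in each test.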
TEST_F(ShapeInferenceTest, InferReshapeSplit) {
  // [<=10]
  //   | reshape
  // [1, 10]
  //
  // Both output dimension can be dynamic, use inferred_dimension to tie-break.
  auto operand = ShapeUtil::MakeShape(F32, {10}, {true});
  auto status = ShapeInference::InferReshapeShape(operand, {0}, {1, 10},
                                                  /*inferred_dimension=*/0);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {1, 10}, {true, false}),
            status.ValueOrDie());
}

TEST_F(ShapeInferenceTest, InferReshapeCombine) {
  // [6, <=10]
  //   | reshape
  // [<=60]
  auto operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
  auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {60},
                                                  /*inferred_dimension=*/-11);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {60}, {true}), status.ValueOrDie());
}

TEST_F(ShapeInferenceTest, UnchangedDimension) {
  // [6, <=10]
  //   | reshape
  // [2, 3, <=10]
  auto operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
  auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {2, 3, 10},
                                                  /*inferred_dimension=*/-11);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {2, 3, 10}, {false, false, true}),
            status.ValueOrDie());
}

TEST_F(ShapeInferenceTest, InferDynamicBroadcast) {
  // CHECK:
  // %broadcast = s32[15,<=15]{1,0} broadcast(s32[<=15]{0}), dimensions={1}
  auto operand_shape = ShapeUtil::MakeShape(F32, {15}, {true});
  auto inferred_status =
      ShapeInference::InferBroadcastShape(operand_shape, {15});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {15, 15}, {false, true}), inferred);
}

TEST_F(ShapeInferenceTest, BroadcastScalar) {
  for (auto element_type : {F32, U32, S8}) {
    const Shape scalar_shape = ShapeUtil::MakeShape(element_type, {});
    {  // no-op scalar broadcast
      auto status = ShapeInference::InferBroadcastShape(scalar_shape, {});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(scalar_shape, status.ValueOrDie()));
    }
    const Shape oned_shape = ShapeUtil::MakeShape(element_type, {3});
    {  // scalar -> 1d broadcast
      auto status = ShapeInference::InferBroadcastShape(scalar_shape, {3});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.ValueOrDie()));
    }
    {  // no-op 1d broadcast
      auto status = ShapeInference::InferBroadcastShape(oned_shape, {});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.ValueOrDie()));
    }
    const Shape twod_shape = ShapeUtil::MakeShape(element_type, {2, 3});
    {  // scalar -> 2d broadcast
      auto status = ShapeInference::InferBroadcastShape(scalar_shape, {2, 3});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.ValueOrDie()));
    }
    {  // 1d -> 2d broadcast
      auto status = ShapeInference::InferBroadcastShape(oned_shape, {2});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.ValueOrDie()));
    }
  }
}
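
// With explicit broadcast_sizes, InferBroadcastShape prepends the new sizes
// as major dimensions: f32[3] broadcast by {2} becomes f32[2, 3] above.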
dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); auto inferred_status_mismatch = ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums, /*preferred_element_type=*/std::nullopt); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix <dot> vector -> vector TEST_F(ShapeInferenceTest, MatrixDotVector) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_)); auto inferred_status_mismatch = ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums, /*preferred_element_type=*/std::nullopt); ASSERT_FALSE(inferred_status_mismatch.ok()); } // vector <dot> matrix -> vector TEST_F(ShapeInferenceTest, VectorDotMatrix) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_)); auto inferred_status_mismatch = ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix <dot> matrix -> matrix TEST_F(ShapeInferenceTest, MatrixDotMatrix) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status_match = ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums, /*preferred_element_type=*/std::nullopt); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE( ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_)) << "inferred: " << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) << " expected: " << ShapeUtil::HumanString(matrix_32_48_); auto inferred_status_mismatch = ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums, /*preferred_element_type=*/std::nullopt); ASSERT_FALSE(inferred_status_mismatch.ok()); } // BatchMatMul with two batch dimensions and one contracting dimension.
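// Informally, [b0, b1, m, k] <dot> [b0, b1, k, n] -> [b0, b1, m, n]; the case
// below is f32[5,2,11,3] <dot> f32[5,2,3,14] -> f32[5,2,11,14].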
TEST_F(ShapeInferenceTest, DotGeneral) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14}); Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(3); dot_dnums.add_lhs_batch_dimensions(0); dot_dnums.add_lhs_batch_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(2); dot_dnums.add_rhs_batch_dimensions(0); dot_dnums.add_rhs_batch_dimensions(1); auto inferred_status_match = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE( ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape)) << "inferred: " << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) << " expected: " << ShapeUtil::HumanString(output_shape); } // BatchMatMul with two contracting dimensions fails. TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(2); dot_dnums.add_lhs_contracting_dimensions(3); dot_dnums.add_lhs_batch_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(1); dot_dnums.add_rhs_batch_dimensions(0); auto inferred_status = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("Must specify the same number of contracting " "dimensions for lhs and rhs.")); } TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 2, 14}); Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(2); dot_dnums.add_lhs_contracting_dimensions(3); dot_dnums.add_lhs_batch_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(2); dot_dnums.add_rhs_batch_dimensions(0); auto inferred_status = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); EXPECT_TRUE(inferred_status.ok()); EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape)); } TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) { Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); Shape val_shape = ShapeUtil::MakeShape(S32, {1}); auto inferred_status = ShapeInference::InferSetDimensionSizeShape( arg_shape, val_shape, /*dimension=*/0); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), HasSubstr("value has to be S32 scalar")); } TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) { Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); Shape val_shape = ShapeUtil::MakeShape(U32, {}); auto inferred_status = ShapeInference::InferSetDimensionSizeShape( arg_shape, val_shape, /*dimension=*/0); EXPECT_FALSE(inferred_status.ok()); EXPECT_THAT(inferred_status.status().error_message(), HasSubstr("value has to be S32 scalar")); } // BatchMatMul with different batch dimension sizes fails. 
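// Batch dimensions must agree pairwise: below, lhs batch dimension 0 has
// size 2 while rhs batch dimension 0 has size 3.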
TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(2); dot_dnums.add_lhs_batch_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(1); dot_dnums.add_rhs_batch_dimensions(0); auto inferred_status = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("Batch dimension sizes must match")); } // BatchMatMul with different batch dimension numbers passes TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersPasses) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(2); dot_dnums.add_lhs_batch_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); dot_dnums.add_rhs_batch_dimensions(1); auto inferred_status = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); ASSERT_TRUE(inferred_status.ok()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), ShapeUtil::MakeShape(F32, {2, 11, 14}))); } // BatchMatMul with out-of-range dimension numbers fails. TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(3); dot_dnums.add_lhs_batch_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); dot_dnums.add_rhs_batch_dimensions(1); auto inferred_status = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("A dimension number is out of range")); } // BatchMatMul with non-unique dimension numbers fails. 
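// Below, lhs dimension 0 is listed both as a batch dimension and as a
// contracting dimension; each dimension number may be used at most once.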
TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_lhs_batch_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); dot_dnums.add_rhs_batch_dimensions(1); auto inferred_status = ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums, /*preferred_element_type=*/std::nullopt); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("A dimension number is not unique")); } TEST_F(ShapeInferenceTest, DotWithIntegralPreferredElementType) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(S8, {32, 32}), ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums, /*preferred_element_type=*/S32)); EXPECT_TRUE( ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(S32, {32, 32}))); } TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeSameAsInferredType) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(BF16, {32, 32}), ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums, /*preferred_element_type=*/F32)); EXPECT_TRUE( ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {32, 32}))); } TEST_F(ShapeInferenceTest, FloatingPointDotWithNarrowerPreferredElementType) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(BF16, {32, 32}), ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums, /*preferred_element_type=*/BF16)); EXPECT_TRUE( ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(BF16, {32, 32}))); } TEST_F(ShapeInferenceTest, FloatingPointDotWithIntegralPreferredElementType) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(BF16, {32, 32}), ShapeUtil::MakeShape(BF16, {32, 32}), dot_dnums, /*preferred_element_type=*/S32)); EXPECT_TRUE( ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(S32, {32, 32}))); } TEST_F(ShapeInferenceTest, IntegralDotWithFloatingPointPreferredElementType) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(S8, {32, 32}), ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums, /*preferred_element_type=*/F32)); EXPECT_TRUE( ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {32, 32}))); } TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeWithDifferentSignedness) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape, ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(S8, {32, 32}), ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums, /*preferred_element_type=*/U32)); EXPECT_TRUE( ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(U32, {32, 32}))); } 
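// The preferred_element_type cases above all widen or re-type the result. The
// next test exercises the converse rule for integral operands: the requested
// type must not be narrower than the naturally inferred one (the
// FloatingPointDotWithNarrowerPreferredElementType case above shows that
// floating-point dots are allowed to narrow).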
TEST_F(ShapeInferenceTest, DotWithNarrowerPreferredElementType) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = ShapeInference::InferDotOpShape( ShapeUtil::MakeShape(S8, {32, 32}), ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums, /*preferred_element_type=*/S8) .status(); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.error_message(), HasSubstr("must not be narrower than the original type")); } TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { // Test variations of broadcasting a vector for a binary add with a // matrix. const Shape mat = ShapeUtil::MakeShape(F32, {16, 8}); const Shape vec8 = ShapeUtil::MakeShape(F32, {8}); const Shape vec16 = ShapeUtil::MakeShape(F32, {16}); auto inferred_status_match = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0}); ASSERT_FALSE(inferred_status_mismatch.ok()); inferred_status_match = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); inferred_status_mismatch = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1}); ASSERT_FALSE(inferred_status_mismatch.ok()); } TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) { // Test variations of broadcasting a matrix for a binary add with a cube. const Shape cube = ShapeUtil::MakeShape(F32, {16, 8, 4}); const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4}); const Shape matrix16_4 = ShapeUtil::MakeShape(F32, {16, 4}); const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8}); auto inferred_status_match = ShapeInference::InferBinaryOpShape( HloOpcode::kAdd, cube, matrix8_4, {1, 2}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); inferred_status_match = ShapeInference::InferBinaryOpShape( HloOpcode::kAdd, cube, matrix16_4, {0, 2}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); inferred_status_match = ShapeInference::InferBinaryOpShape( HloOpcode::kAdd, cube, matrix16_8, {0, 1}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); } TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { // Test various errors with the broadcast argument. 
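// broadcast_dimensions maps each dimension of the lower-rank operand to a
// dimension of the higher-rank operand; the mapped sizes must match and the
// mapping must be strictly increasing.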
const Shape tensor = ShapeUtil::MakeShape(F32, {16, 8, 4}); const Shape tensor8_8_8 = ShapeUtil::MakeShape(F32, {8, 8, 8}); const Shape vec8 = ShapeUtil::MakeShape(F32, {8}); const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4}); const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8}); // "magical" broadcast rejected auto inferred_status_error1 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("Shapes must be equal rank")); // broadcast_dimension out of bounds for tensor's rank auto inferred_status_error2 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3}); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), ContainsRegex("Broadcast dimension number .* too large")); // broadcast_dimension doesn't match corresponding dimension auto inferred_status_error3 = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("Broadcast dimension 0 mismatch")); // broadcast_dimensions list too long auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2}); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT(inferred_status_error4.status().error_message(), HasSubstr("broadcast_dimensions has to match")); // there's a dimension above the rank of the tensor auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( HloOpcode::kAdd, tensor, matrix8_4, {3, 0}); ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT(inferred_status_error5.status().error_message(), ContainsRegex("dimension number .* too large")); // broadcasting dimensions don't match in this order auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( HloOpcode::kAdd, tensor, matrix8_4, {2, 1}); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), HasSubstr("dimension 0 mismatch")); // The following two tests make sure that broadcasting dimensions are listed // in a proper (strictly increasing) order, even if the lower-rank array // matches the higher-rank array in many different ways. auto inferred_status_error7 = ShapeInference::InferBinaryOpShape( HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0}); ASSERT_FALSE(inferred_status_error7.ok()); ASSERT_THAT(inferred_status_error7.status().error_message(), HasSubstr("dimensions order is wrong")); auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0}); ASSERT_FALSE(inferred_status_error8.ok()); ASSERT_THAT(inferred_status_error8.status().error_message(), HasSubstr("dimensions order is wrong")); } // Tests for the while instruction with proper shapes. TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) { Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_}); ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_); ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape); auto inferred_status = ShapeInference::InferWhileShape(cond, body, result_shape); ASSERT_IS_OK(inferred_status.status()); Shape inferred = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(result_shape, inferred)); } // Tests for the while instruction with wrong shapes. 
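// A well-formed while has cond: (T) -> PRED and body: (T) -> T for the same
// loop-carried shape T; each case below breaks exactly one of these rules.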
TEST_F(ShapeInferenceTest, WhileWithBadShapes) { Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_}); ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_); ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape); auto bad_shape_1 = ShapeUtil::MakeProgramShape({s32_, result_shape}, pred_); auto inferred_status_error1 = ShapeInference::InferWhileShape(bad_shape_1, body, result_shape); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("Condition must take 1 arguments")); auto bad_shape_2 = ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape); auto inferred_status_error2 = ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), HasSubstr("Body must take 1 arguments")); auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_); auto inferred_status_error3 = ShapeInference::InferWhileShape(bad_shape_3, body, result_shape); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("Condition must return a boolean")); auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_); auto inferred_status_error4 = ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT(inferred_status_error4.status().error_message(), HasSubstr("parameter of condition and body")); } // Tests for the concatenate instruction with dynamic shapes. TEST_F(ShapeInferenceTest, ConcatenateWithDynamicShapes) { auto dynamic_shape_1 = ShapeUtil::MakeShape(F32, {32, 160, 10}, {true, false, false}); auto dynamic_shape_2 = ShapeUtil::MakeShape(F32, {32, 160, 10}, {false, true, false}); auto inferred_status = ShapeInference::InferConcatOpShape( {&dynamic_shape_1, &dynamic_shape_2}, /*dimension=*/0); ASSERT_IS_OK(inferred_status.status()); Shape inferred = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal( ShapeUtil::MakeShape(F32, {64, 160, 10}, {true, true, false}), inferred)); } // Tests for the concatenate instruction with proper shapes. TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) { auto inferred_status_1 = ShapeInference::InferConcatOpShape( {&vector_32_, &vector_64_}, /*dimension=*/0); ASSERT_IS_OK(inferred_status_1.status()); Shape inferred_1 = inferred_status_1.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), inferred_1)); auto inferred_status_2 = ShapeInference::InferConcatOpShape( {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0); ASSERT_IS_OK(inferred_status_2.status()); Shape inferred_2 = inferred_status_2.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), inferred_2)); auto inferred_status_3 = ShapeInference::InferConcatOpShape( {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1); ASSERT_IS_OK(inferred_status_3.status()); Shape inferred_3 = inferred_status_3.ValueOrDie(); ASSERT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}), inferred_3)); } // Tests for the concatenate instruction with wrong shapes. 
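// Concatenation needs at least one array operand, an in-range dimension, a
// common element type, and agreement on every non-concatenated dimension;
// each case below violates one of these requirements.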
TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { auto inferred_status_error1 = ShapeInference::InferConcatOpShape({}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("Concatenate expects at least one argument")); auto inferred_status_error2 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), HasSubstr("dimension out of bounds: -1")); auto inferred_status_error3 = ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("dimension out of bounds: 1")); Shape tuple = ShapeUtil::MakeTupleShape({vector_32_}); auto inferred_status_error4 = ShapeInference::InferConcatOpShape( {&vector_32_, &tuple}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT( inferred_status_error4.status().error_message(), HasSubstr("Expected array argument for operand of concatenation")); const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); auto inferred_status_error5 = ShapeInference::InferConcatOpShape( {&vector_32_, &vector_s32}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT(inferred_status_error5.status().error_message(), HasSubstr("concatenate arrays with different element types")); auto inferred_status_error6 = ShapeInference::InferConcatOpShape( {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), HasSubstr("concatenate arrays that differ in " "dimensions other than the one being " "concatenated")); } TEST_F(ShapeInferenceTest, Pad) { Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); Shape padding_value_shape = ShapeUtil::MakeShape(F32, {}); // Padding for dimension 0: {low: 0, high: 2, interior: 3} // Padding for dimension 1: {low: 1, high: 5, interior: 0} PaddingConfig padding_config; auto dimension0 = padding_config.add_dimensions(); dimension0->set_edge_padding_low(0); dimension0->set_edge_padding_high(2); dimension0->set_interior_padding(3); auto dimension1 = padding_config.add_dimensions(); dimension1->set_edge_padding_low(1); dimension1->set_edge_padding_high(5); dimension1->set_interior_padding(0); auto inferred_status = ShapeInference::InferPadShape( input_shape, padding_value_shape, padding_config); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), inferred_shape)); dimension1->set_edge_padding_low(-20); dimension1->set_edge_padding_high(-10); auto negative_dimension_size = ShapeInference::InferPadShape( input_shape, padding_value_shape, padding_config); ASSERT_FALSE(negative_dimension_size.ok()); ASSERT_THAT(negative_dimension_size.status().error_message(), HasSubstr("negative size for dimension 1")); } TEST_F(ShapeInferenceTest, Reverse) { Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); auto inferred_status = ShapeInference::InferReverseShape(input_shape, {0, 1}); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(input_shape, inferred_shape)); } TEST_F(ShapeInferenceTest, ReverseInvalidDimension) { Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25}); auto inferred_status_error0 = 
ShapeInference::InferReverseShape(input_shape, {0, 2}); ASSERT_FALSE(inferred_status_error0.ok()); ASSERT_THAT(inferred_status_error0.status().error_message(), HasSubstr("out-of-bounds")); auto inferred_status_error1 = ShapeInference::InferReverseShape(input_shape, {0, -1}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("out-of-bounds")); auto inferred_status_error2 = ShapeInference::InferReverseShape(input_shape, {0, 0}); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), HasSubstr("duplicated")); Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape}); auto inferred_status_error3 = ShapeInference::InferReverseShape(tuple_shape, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("Expected array argument")); } TEST_F(ShapeInferenceTest, Call) { auto inferred_status0 = ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({}, f32_)); EXPECT_IS_OK(inferred_status0.status()); EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie())); auto inferred_status1 = ShapeInference::InferCallShape( {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_}, ShapeUtil::MakeProgramShape( {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_)); EXPECT_IS_OK(inferred_status1.status()); EXPECT_TRUE( ShapeUtil::Equal(s32matrix_64_64_, inferred_status1.ValueOrDie())); auto inferred_status_error0 = ShapeInference::InferCallShape( {}, ShapeUtil::MakeProgramShape({f32_}, f32_)); EXPECT_FALSE(inferred_status_error0.ok()); EXPECT_THAT(inferred_status_error0.status().error_message(), HasSubstr("arity must match")); auto inferred_status_error1 = ShapeInference::InferCallShape( {&f32_}, ShapeUtil::MakeProgramShape({}, f32_)); EXPECT_FALSE(inferred_status_error1.ok()); EXPECT_THAT(inferred_status_error1.status().error_message(), HasSubstr("arity must match")); auto inferred_status_error2 = ShapeInference::InferCallShape( {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_)); EXPECT_FALSE(inferred_status_error2.ok()); EXPECT_THAT(inferred_status_error2.status().error_message(), HasSubstr("parameter must match argument")); } TEST_F(ShapeInferenceTest, Transpose) { Shape a_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5}); auto inferred_shape_and_status = ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0}); EXPECT_IS_OK(inferred_shape_and_status); Shape inferred_shape = inferred_shape_and_status.ValueOrDie(); EXPECT_TRUE(ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {3, 4, 5, 2}))); } TEST_F(ShapeInferenceTest, Rank1Transpose) { Shape a_shape = ShapeUtil::MakeShape(F32, {5}); auto inferred_shape_and_status = ShapeInference::InferTransposeShape(a_shape, {0}); EXPECT_IS_OK(inferred_shape_and_status); Shape inferred_shape = inferred_shape_and_status.ValueOrDie(); EXPECT_TRUE( ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {5}))); } TEST_F(ShapeInferenceTest, ConditionalPred) { auto inferred_status0 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, {vector_32_, vector_64_}); EXPECT_IS_OK(inferred_status0.status()); EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie())); auto inferred_status1 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)}, 
{matrix_32_48_, vector_32_}); EXPECT_IS_OK(inferred_status1.status()); EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie())); auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); auto inferred_status2 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)}, {matrix_32_48_, tuple_f32_v32}); EXPECT_IS_OK(inferred_status2.status()); EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie())); auto inferred_status_error0 = ShapeInference::InferConditionalShape( f32_, {ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, {vector_32_, vector_64_}); EXPECT_FALSE(inferred_status_error0.ok()); EXPECT_THAT(inferred_status_error0.status().error_message(), HasSubstr("must be bool or int32_t")); auto inferred_status_error1 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)}, {ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_}); EXPECT_FALSE(inferred_status_error1.ok()); EXPECT_THAT(inferred_status_error1.status().error_message(), HasSubstr("branch computation 0 must take 1 argument")); auto inferred_status_error2 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({vector_64_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, {vector_32_, vector_64_}); EXPECT_FALSE(inferred_status_error2.ok()); EXPECT_THAT(inferred_status_error2.status().error_message(), HasSubstr("branch operand 0 must match the shape of the only " "parameter of branch computation 0")); auto inferred_status_error3 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)}, {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})}); EXPECT_FALSE(inferred_status_error3.ok()); EXPECT_THAT(inferred_status_error3.status().error_message(), HasSubstr("branch computation 1 must take 1 argument")); auto inferred_status_error4 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_32_}, f32_)}, {vector_32_, vector_64_}); EXPECT_FALSE(inferred_status_error4.ok()); EXPECT_THAT(inferred_status_error4.status().error_message(), HasSubstr("branch operand 1 must match the shape of the only " "parameter of branch computation 1")); auto inferred_status_error5 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)}, {vector_32_, vector_64_}); EXPECT_FALSE(inferred_status_error5.ok()); EXPECT_THAT(inferred_status_error5.status().error_message(), HasSubstr("the result of branch 0 computation and branch 1 " "computation must have the same shape")); } TEST_F(ShapeInferenceTest, ConditionalIndexed) { auto r0s32 = ShapeUtil::MakeShape(S32, {}); auto inferred_status0 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, {vector_32_, vector_64_, vector_64_}); EXPECT_IS_OK(inferred_status0.status()); EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie())); auto inferred_status1 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({matrix_32_48_}, 
vector_64_), ShapeUtil::MakeProgramShape({vector_32_}, vector_64_), ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)}, {matrix_32_48_, vector_32_, matrix_32_48_}); EXPECT_IS_OK(inferred_status1.status()); EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie())); auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); auto inferred_status2 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)}, {tuple_f32_v32}); EXPECT_IS_OK(inferred_status2.status()); EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie())); auto inferred_status_error0 = ShapeInference::InferConditionalShape( pred_, {ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, f32_)}, {vector_32_, vector_32_, vector_64_}); EXPECT_FALSE(inferred_status_error0.ok()); EXPECT_THAT(inferred_status_error0.status().error_message(), HasSubstr("2 == branch_computations.size()")); auto inferred_status_error1 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)}, {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_}); EXPECT_FALSE(inferred_status_error1.ok()); EXPECT_THAT(inferred_status_error1.status().error_message(), HasSubstr("branch computation 1 must take 1 argument")); auto inferred_status_error2 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({r0s32}, f32_), ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_32_}, f32_)}, {r0s32, vector_32_, vector_64_}); EXPECT_FALSE(inferred_status_error2.ok()); EXPECT_THAT(inferred_status_error2.status().error_message(), HasSubstr("branch operand 2 must match the shape of the only " "parameter of branch computation 2")); auto inferred_status_error3 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_32_}, f32_), ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)}, {vector_32_, vector_32_, vector_32_, vector_64_}); EXPECT_FALSE(inferred_status_error3.ok()); EXPECT_THAT(inferred_status_error3.status().error_message(), HasSubstr("the result of branch 0 computation and branch 3 " "computation must have the same shape")); auto inferred_status_error4 = ShapeInference::InferConditionalShape(r0s32, {}, {}); EXPECT_FALSE(inferred_status_error4.ok()); EXPECT_THAT(inferred_status_error4.status().error_message(), HasSubstr("!branch_computations.empty()")); } TEST_F(ShapeInferenceTest, ConditionalDynamic) { auto r0s32 = ShapeUtil::MakeShape(S32, {}); auto static_shape = ShapeUtil::MakeShape(S32, {4}, {false}); auto dynamic_shape = ShapeUtil::MakeShape(S32, {4}, {true}); auto inferred_status0 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({vector_32_}, static_shape), ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape), ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)}, {vector_32_, vector_64_, vector_64_}); EXPECT_IS_OK(inferred_status0.status()); EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, inferred_status0.ValueOrDie())); auto inferred_status1 = ShapeInference::InferConditionalShape( r0s32, {ShapeUtil::MakeProgramShape({vector_32_}, dynamic_shape), ShapeUtil::MakeProgramShape({vector_64_}, 
static_shape), ShapeUtil::MakeProgramShape({vector_64_}, dynamic_shape)}, {vector_32_, vector_64_, vector_64_}); EXPECT_IS_OK(inferred_status1.status()); EXPECT_TRUE(ShapeUtil::Equal(dynamic_shape, inferred_status1.ValueOrDie())); } TEST_F(ShapeInferenceTest, BadSlice) { auto arg = ShapeUtil::MakeShape(F32, {4}); StatusOr<Shape> statusor = ShapeInference::InferSliceShape(arg, {0}, {5}, {1}); ASSERT_FALSE(statusor.ok()); LOG(INFO) << statusor.status(); EXPECT_THAT(statusor.status().error_message(), HasSubstr("less than or equal to dimension size")) << statusor.status(); EXPECT_THAT(statusor.status().error_message(), HasSubstr("argument shape")) << statusor.status(); } TEST_F(ShapeInferenceTest, BadSort) { auto keys = ShapeUtil::MakeShape(F32, {4}); auto values = ShapeUtil::MakeShape(F32, {5}); StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values}); EXPECT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("dimensions must match")) << statusor.status(); } TEST_F(ShapeInferenceTest, BadSortValuesMismatch) { auto keys = ShapeUtil::MakeShape(F32, {4}); auto values_good = ShapeUtil::MakeShape(F32, {4}); auto values_bad = ShapeUtil::MakeShape(F32, {5}); StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape( HloOpcode::kSort, {&keys, &values_good, &values_bad}); EXPECT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("dimensions must match")) << statusor.status(); } TEST_F(ShapeInferenceTest, SortManyValues) { auto keys = ShapeUtil::MakeShape(F32, {4}); auto values_s32 = ShapeUtil::MakeShape(S32, {4}); auto values_u32 = ShapeUtil::MakeShape(U32, {4}); StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape( HloOpcode::kSort, {&keys, &values_s32, &values_u32}); EXPECT_IS_OK(statusor); Shape inferred_shape = statusor.ValueOrDie(); EXPECT_TRUE(ShapeUtil::Compatible( inferred_shape, ShapeUtil::MakeTupleShape({keys, values_s32, values_u32}))); } class GatherShapeInferenceTest : public ShapeInferenceTest { protected: const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5}); const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32}); const Shape s64_4d_tensor_10_9_8_7_1_ = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}); const Shape s64_4d_tensor_10_9_8_7_5_ = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); const Shape s64_4d_tensor_5_10_9_7_6_ = ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6}); const Shape s64_4d_tensor_10_9_5_7_6_ = ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); const Shape f32_5d_tensor_50_49_48_47_46_ = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_}); }; TEST_F(GatherShapeInferenceTest, TensorFlowGather) { TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{0}, /*collapsed_slice_dims=*/{1}, /*start_index_map=*/{1}, /*index_vector_dim=*/1), /*slice_sizes=*/{64, 1})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{1}, /*collapsed_slice_dims=*/{0}, /*start_index_map=*/{0}, /*index_vector_dim=*/1), /*slice_sizes=*/{1, 48}));
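// Each of the 32 indices selects a [1,48] row slice of the 64x48 operand;
// collapsing slice dimension 0 leaves an f32[32,48] result.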
EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4}, /*collapsed_slice_dims=*/{0}, /*start_index_map=*/{0}, /*index_vector_dim=*/4), /*slice_sizes=*/{1, 48})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, /*collapsed_slice_dims=*/{}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, DynamicGatherEntireDimension) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( ShapeUtil::MakeShape(F32, {3, 2, 1}, {false, true, false}), ShapeUtil::MakeShape(S64, {}), HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{0, 1}, /*collapsed_slice_dims=*/{0}, /*start_index_map=*/{0}, /*index_vector_dim=*/0), /*slice_sizes=*/{1, 2, 1})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, ShapeUtil::MakeShape(F32, {2, 1}, {true, false}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, DynamicGatherCollapsedDimension) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( ShapeUtil::MakeShape(F32, {3, 2, 1}, {true, false, false}), ShapeUtil::MakeShape(S64, {}), HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{0, 1}, /*collapsed_slice_dims=*/{0}, /*start_index_map=*/{0}, /*index_vector_dim=*/0), /*slice_sizes=*/{1, 2, 1})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, ShapeUtil::MakeShape(F32, {2, 1}, {false, false}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, DynamicIndices) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( ShapeUtil::MakeShape(F32, {3, 2, 2}), ShapeUtil::MakeShape(S64, {3, 4, 2}, {false, true, false}), HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{2, 3}, /*collapsed_slice_dims=*/{0}, /*start_index_map=*/{0, 1}, /*index_vector_dim=*/2), /*slice_sizes=*/{1, 2, 2})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, ShapeUtil::MakeShape(F32, {3, 4, 2, 2}, {false, true, false, false}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, /*collapsed_slice_dims=*/{}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/2), /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( 
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, /*collapsed_slice_dims=*/{}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/0), /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal( gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { // This is equivalent to a dynamic slice. TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{0, 1, 2, 3, 4}, /*collapsed_slice_dims=*/{}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/0), /*slice_sizes=*/{30, 29, 28, 27, 26})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { // The gather indices "tensor" is a scalar S here that's used to slice out // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result. TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_scalar_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{0, 1, 2, 3}, /*collapsed_slice_dims=*/{0}, /*start_index_map=*/{0}, /*index_vector_dim=*/0), /*slice_sizes=*/{1, 30, 29, 28, 27})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {30, 29, 28, 27}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{0}, /*collapsed_slice_dims=*/{1}, /*start_index_map=*/{1}, /*index_vector_dim=*/1), /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Expected array argument for input")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( s64_vector_32_, tuple_shape_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{0}, /*collapsed_slice_dims=*/{1}, /*start_index_map=*/{1}, /*index_vector_dim=*/0), /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Expected array argument for gather indices")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{0}, /*collapsed_slice_dims=*/{1}, /*start_index_map=*/{1}, /*index_vector_dim=*/0), /*slice_sizes=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Gather indices parameter must be an integral tensor")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingWindowIndices) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 8, 7}, /*collapsed_slice_dims=*/{}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), HasSubstr("Output 
window dimensions in gather op must be ascending")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowIndices) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 7}, /*collapsed_slice_dims=*/{}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), HasSubstr("Output window dimensions in gather op must not repeat")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexOutOfBounds) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 99, 100, 101}, /*collapsed_slice_dims=*/{}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Offset dimension 2 in gather op is out of bounds")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 9}, /*collapsed_slice_dims=*/{}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Offset dimension 4 in gather op is out of bounds")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingElidedWindowDims) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, /*collapsed_slice_dims=*/{4}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), HasSubstr("All components of the offset index in a gather op must either " "be a offset dimension or explicitly collapsed")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, /*collapsed_slice_dims=*/{0, 1, 2, 3, 19}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Invalid collapsed_slice_dims set in gather op; valid " "range is [0, 5), got: 19")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, /*collapsed_slice_dims=*/{0, 1, 2, 3, 3}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), 
HasSubstr("Repeated dimensions not allowed in " "collapsed_slice_dims in gather op")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, /*collapsed_slice_dims=*/{}, /*start_index_map=*/{0, 1, 2, 3}, /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Gather op has 4 elements in start_index_map and " "the bound of dimension index_vector_dim=4 of " "start_indices is 5. These two numbers must be equal.")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, /*collapsed_slice_dims=*/{}, /*start_index_map=*/{0, 1, 2, 3, 7}, /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, /*collapsed_slice_dims=*/{}, /*start_index_map=*/{0, 1, 2, 3, 3}, /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), HasSubstr("Repeated dimensions are not allowed in start_index_map")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, /*collapsed_slice_dims=*/{2, 1}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), /*slice_sizes=*/{1, 1, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("collapsed_slice_dims in gather op must be sorted")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7}, /*collapsed_slice_dims=*/{2}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 1, 300, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Slice size at index 3 in gather op is out of range, " "must be within [0, 48), got 300.")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { StatusOr statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, /*collapsed_slice_dims=*/{}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 26}); ASSERT_FALSE(statusor.ok()); 
EXPECT_THAT( statusor.status().error_message(), HasSubstr("Gather op must have one slice size for every input dimension")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7}, /*collapsed_slice_dims=*/{1}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), /*slice_sizes=*/{30, 29, 28, 26, 20}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), HasSubstr("Gather op can only collapse slice dims with bound 1 or 0, " "but bound is 29 for index 1 at position 0.")) << statusor.status(); } TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( /*offset_dims=*/{4, 5, 6, 7, 8}, /*collapsed_slice_dims=*/{}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/32), /*slice_sizes=*/{30, 29, 28, 27, 26}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Gather index leaf dimension must be within [0, " "rank(start_indices) + 1)")) << statusor.status(); } class ScatterShapeInferenceTest : public ShapeInferenceTest, public ::testing::WithParamInterface<std::vector<PrimitiveType>> { protected: struct ScatterShapes { void Add(Shape shape) { shapes.push_back(std::move(shape)); ptrs.push_back(&shapes.back()); } std::vector<Shape> shapes; std::vector<const Shape*> ptrs; }; static ScatterShapes CreateShapes(absl::Span<const int64_t> operand_dims, const Shape& scatter_indices_shape, absl::Span<const int64_t> update_dims, absl::Span<const PrimitiveType> types) { CHECK(!types.empty()); size_t size = types.size() * 2 + 1; ScatterShapes shapes; shapes.shapes.reserve(size); shapes.ptrs.reserve(size); for (PrimitiveType type : types) { shapes.Add(ShapeUtil::MakeShape(type, operand_dims)); } shapes.Add(scatter_indices_shape); for (PrimitiveType type : types) { shapes.Add(ShapeUtil::MakeShape(type, update_dims)); } return shapes; } static Shape Collate(absl::Span<const int64_t> dims, absl::Span<const PrimitiveType> types) { CHECK(!types.empty()); if (types.size() == 1) { return ShapeUtil::MakeShape(types[0], dims); } std::vector<Shape> shapes; for (PrimitiveType type : types) { shapes.push_back(ShapeUtil::MakeShape(type, dims)); } return ShapeUtil::MakeTupleShape(shapes); } static Shape scalar(PrimitiveType type) { return ShapeUtil::MakeShape(type, {}); } static Shape s64_vector(int dim) { return ShapeUtil::MakeShape(S64, {dim}); } static Shape s64_tensor(absl::Span<const int64_t> dims) { return ShapeUtil::MakeShape(S64, dims); } static ProgramShape to_apply(absl::Span<const PrimitiveType> types) { CHECK(!types.empty()); ProgramShape program_shape; Shape& result = *program_shape.mutable_result(); result = ShapeUtil::MakeNil(); result.mutable_tuple_shapes()->reserve(types.size()); program_shape.mutable_parameters()->reserve(types.size() * 2); for (PrimitiveType type : types) { *program_shape.add_parameters() = scalar(type); *result.add_tuple_shapes() = scalar(type); } for (PrimitiveType type : types) { *program_shape.add_parameters() = scalar(type); } return program_shape; } std::vector<PrimitiveType> types() const { return GetParam(); } }; TEST_P(ScatterShapeInferenceTest, TfScatterWithFullUpdates) { auto shapes = CreateShapes({64, 48}, s64_vector(32), {64, 32}, types()); TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), 
HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, /*inserted_window_dims=*/{1}, /*scatter_dims_to_operand_dims=*/{1}, /*index_vector_dim=*/1))); EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types()))) << ShapeUtil::HumanString(scatter_shape); } TEST_P(ScatterShapeInferenceTest, TfScatterWithFullUpdatesV2) { auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 48}, types()); TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{1}, /*inserted_window_dims=*/{0}, /*scatter_dims_to_operand_dims=*/{0}, /*index_vector_dim=*/1))); EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types()))) << ShapeUtil::HumanString(scatter_shape); } TEST_P(ScatterShapeInferenceTest, TfScatterWithPartialUpdates) { auto shapes = CreateShapes({64, 48}, s64_vector(32), {10, 32}, types()); TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, /*inserted_window_dims=*/{1}, /*scatter_dims_to_operand_dims=*/{1}, /*index_vector_dim=*/1))); EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types()))) << ShapeUtil::HumanString(scatter_shape); } TEST_P(ScatterShapeInferenceTest, TfScatterWithPartialUpdatesV2) { auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 8}, types()); TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{1}, /*inserted_window_dims=*/{0}, /*scatter_dims_to_operand_dims=*/{0}, /*index_vector_dim=*/1))); EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types()))) << ShapeUtil::HumanString(scatter_shape); } TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) { auto shapes = CreateShapes({64, 48}, s64_vector(32), {65, 32}, types()); StatusOr<Shape> statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, /*inserted_window_dims=*/{1}, /*scatter_dims_to_operand_dims=*/{1}, /*index_vector_dim=*/1)); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), HasSubstr("Bounds of the window dimensions of updates must not exceed " "the bounds of the corresponding dimensions of operand.")) << statusor.status(); } TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) { auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 49}, types()); StatusOr<Shape> statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{1}, /*inserted_window_dims=*/{0}, /*scatter_dims_to_operand_dims=*/{1}, /*index_vector_dim=*/1)); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), HasSubstr("Bounds of the window dimensions of updates must not exceed " "the bounds of the corresponding dimensions of operand.")) << statusor.status(); } TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndices) { auto shapes = CreateShapes({64, 48}, s64_vector(32), {64, 31}, types()); StatusOr<Shape> statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{0}, /*inserted_window_dims=*/{1}, /*scatter_dims_to_operand_dims=*/{1}, 
TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) {
  auto shapes = CreateShapes({64, 48}, s64_vector(32), {65, 32}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << statusor.status();
}

TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) {
  auto shapes = CreateShapes({64, 48}, s64_vector(32), {32, 49}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{1},
          /*inserted_window_dims=*/{0},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << statusor.status();
}

TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndices) {
  auto shapes = CreateShapes({64, 48}, s64_vector(32), {64, 31}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << statusor.status();
}

TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndicesV2) {
  auto shapes = CreateShapes({64, 48}, s64_vector(32), {31, 48}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{1},
          /*inserted_window_dims=*/{0},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << statusor.status();
}
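// Illustrative note: the TfScatterNd* tests below use tf.scatter_nd-style
// indices: a [10,9,8,7,1] tensor whose trailing dimension (index_vector_dim=4,
// of size 1) holds the index vector, so there are 10*9*8*7 independent
// one-component start indices, each selecting a one-dimensional slice of the
// [64,48] operand to update.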
TEST_P(ScatterShapeInferenceTest, TfScatterNdWithFullUpdates) {
  auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
                             {10, 9, 8, 7, 48}, types());
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              shapes.ptrs, to_apply(types()),
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{4},
                                  /*inserted_window_dims=*/{0},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_P(ScatterShapeInferenceTest, TfScatterNdWithFullUpdatesV2) {
  auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
                             {10, 9, 8, 7, 64}, types());
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              shapes.ptrs, to_apply(types()),
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{4},
                                  /*inserted_window_dims=*/{1},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_P(ScatterShapeInferenceTest, TfScatterNdWithPartialUpdates) {
  auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
                             {10, 9, 8, 7, 10}, types());
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              shapes.ptrs, to_apply(types()),
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{4},
                                  /*inserted_window_dims=*/{0},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_P(ScatterShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) {
  auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
                             {10, 9, 8, 7, 12}, types());
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              shapes.ptrs, to_apply(types()),
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{4},
                                  /*inserted_window_dims=*/{1},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, Collate({64, 48}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_P(ScatterShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) {
  auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
                             {10, 9, 8, 7, 65}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << statusor.status();
}

TEST_P(ScatterShapeInferenceTest, TfScatterNdWithUpdatesNotMatchingIndices) {
  auto shapes = CreateShapes({64, 48}, s64_tensor({10, 9, 8, 7, 1}),
                             {9, 9, 8, 7, 64}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << statusor.status();
}

TEST_P(ScatterShapeInferenceTest, TfBatchDynamicUpdateSlice) {
  auto shapes =
      CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                   {10, 9, 8, 7, 30, 29, 28, 27, 26}, types());
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          shapes.ptrs, to_apply(types()),
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(
      ShapeUtil::Equal(scatter_shape, Collate({50, 49, 48, 47, 46}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_P(ScatterShapeInferenceTest, NonDefaultScatterIndicesLeafDim) {
  auto shapes =
      CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 5, 7, 6}),
                   {10, 9, 7, 6, 30, 29, 28, 27, 26}, types());
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          shapes.ptrs, to_apply(types()),
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/2)));
  EXPECT_TRUE(
      ShapeUtil::Equal(scatter_shape, Collate({50, 49, 48, 47, 46}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_P(ScatterShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) {
  auto shapes =
      CreateShapes({50, 49, 48, 47, 46}, s64_tensor({5, 10, 9, 7, 6}),
                   {10, 9, 7, 6, 30, 29, 28, 27, 26}, types());
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          shapes.ptrs, to_apply(types()),
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0)));
  EXPECT_TRUE(
      ShapeUtil::Equal(scatter_shape, Collate({50, 49, 48, 47, 46}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_P(ScatterShapeInferenceTest, NoUpdateScatterDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_vector(5),
                             {30, 29, 28, 27, 26}, types());
  // This is equivalent to a dynamic update slice.
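  // Every dimension of updates is a window dimension, and the rank-1 indices
  // form a single five-component index vector (index_vector_dim=0 points at
  // the vector's only dimension), so the scatter writes one
  // [30,29,28,27,26] window at the indexed offset of the operand, which is
  // the same operand/indices/update shape contract DynamicUpdateSlice has.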
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          shapes.ptrs, to_apply(types()),
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{0, 1, 2, 3, 4},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0)));
  EXPECT_TRUE(
      ShapeUtil::Equal(scatter_shape, Collate({50, 49, 48, 47, 46}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_P(ScatterShapeInferenceTest, ScalarScatterIndices) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, scalar(S64),
                             {30, 29, 28, 27}, types());
  // The scalar indices "tensor" is a scalar S here that's used to update a
  // [30,29,28,27] shaped tensor within the operand at position S.
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              shapes.ptrs, to_apply(types()),
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{0, 1, 2, 3},
                                  /*inserted_window_dims=*/{0},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/0)));
  EXPECT_TRUE(
      ShapeUtil::Equal(scatter_shape, Collate({50, 49, 48, 47, 46}, types())))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedTensorInput) {
  Shape tuple_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}),
                                 ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1})});
  Shape s64_vector_32 = s64_vector(32);
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      {&tuple_shape, &s64_vector_32, &s64_vector_32}, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Expected array argument for operand"))
      << statusor.status();
}

TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedScatterIndicesInput) {
  Shape tuple_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}),
                                 ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1})});
  Shape s64_vector_32 = s64_vector(32);
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      {&s64_vector_32, &tuple_shape, &s64_vector_32}, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Expected array argument for scatter indices"))
      << statusor.status();
}

TEST_P(ScatterShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) {
  Shape tuple_shape =
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}),
                                 ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1})});
  Shape s64_vector_32 = s64_vector(32);
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      {&s64_vector_32, &s64_vector_32, &tuple_shape}, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Expected array argument for updates"))
      << statusor.status();
}
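// Illustrative note: the next test swaps in vector_32_ (an F32 vector from
// the base fixture) as the indices operand, checking that scatter indices
// are rejected unless they have an integral element type.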
TEST_P(ScatterShapeInferenceTest, FloatingPointScatterIndicesInput) {
  Shape s64_vector_32 = s64_vector(32);
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      {&s64_vector_32, &vector_32_, &s64_vector_32}, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Scatter indices parameter must be an integral tensor"))
      << statusor.status();
}

TEST_P(ScatterShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/10));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Scatter index leaf dimension must be within [0, "
                        "rank(scatter_indices) + 1)"))
      << statusor.status();
}

TEST_P(ScatterShapeInferenceTest, InvalidUpdates) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28, 50}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Updates tensor must be of rank 7; got 8."))
      << statusor.status();
}

TEST_P(ScatterShapeInferenceTest, InvalidUpdateComputation) {
  const ProgramShape invalid_update_computation =
      ShapeUtil::MakeProgramShape({f32_}, f32_);
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, invalid_update_computation,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr(absl::Substitute(
                  "Reduction function must take $0 parameters, but takes 1",
                  2 * types().size())))
      << statusor.status();
}

TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28, 27, 26}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6, 8, 7},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("update_window_dims in scatter op must be sorted"))
      << statusor.status();
}

TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedUpdateWindowDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28, 27, 26}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6, 7, 7},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("update_window_dims in scatter op must not repeat"))
      << statusor.status();
}
repeat")) << statusor.status(); } TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28, 27, 26}, types()); StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6, 7, 9}, /*inserted_window_dims=*/{}, /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4)); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Invalid update_window_dims set in scatter op; valid " "range is [0, 9)")) << statusor.status(); } TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28}, types()); StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, /*inserted_window_dims=*/{2, 1}, /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4)); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("inserted_window_dims in scatter op must be sorted")) << statusor.status(); } TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_RepeatedInsertedWindowDims) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28}, types()); StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, /*inserted_window_dims=*/{1, 1}, /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4)); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("inserted_window_dims in scatter op must not repeat")) << statusor.status(); } TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28}, types()); StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, /*inserted_window_dims=*/{1, 5}, /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4)); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), HasSubstr("Invalid inserted_window_dims set in scatter op; valid " "range is [0, 5)")) << statusor.status(); } TEST_P(ScatterShapeInferenceTest, InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) { auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}), {10, 9, 8, 7, 30, 29, 28}, types()); StatusOr statusor = ShapeInference::InferScatterShape( shapes.ptrs, to_apply(types()), HloScatterInstruction::MakeScatterDimNumbers( /*update_window_dims=*/{4, 5, 6}, /*inserted_window_dims=*/{1, 2}, /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3}, /*index_vector_dim=*/4)); ASSERT_FALSE(statusor.ok()); EXPECT_THAT( statusor.status().error_message(), HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and " "the bound of dimension index_vector_dim=4 of scatter_indices " "is 5. 
TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims "
                "and the bound of dimension index_vector_dim=4 of "
                "scatter_indices is 5. These two numbers must be equal"))
      << statusor.status();
}

TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain "
                        "is [0, 5), got: 4->10"))
      << statusor.status();
}

TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, s64_tensor({10, 9, 8, 7, 5}),
                             {10, 9, 8, 7, 30, 29, 28}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Repeated dimensions not allowed in scatter_dims_to_operand_dims"))
      << statusor.status();
}

TEST_P(ScatterShapeInferenceTest,
       InvalidScatterDimNumbers_InsufficientWindowDims) {
  auto shapes = CreateShapes({50, 49, 48, 47, 46}, scalar(S64),
                             {30, 29, 28, 27}, types());
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      shapes.ptrs, to_apply(types()),
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0, 1, 2, 3},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Scatter op has window of size 4; doesn't match operand of rank 5."))
      << statusor.status();
}

struct ScatterTestName {
  std::string operator()(
      const ::testing::TestParamInfo<std::vector<PrimitiveType>>& info) const {
    return absl::StrJoin(info.param, "_", absl::StreamFormatter());
  }
};

INSTANTIATE_TEST_SUITE_P(All, ScatterShapeInferenceTest,
                         ::testing::Values(std::vector<PrimitiveType>{F32},
                                           std::vector<PrimitiveType>{F32,
                                                                      BF16}),
                         ScatterTestName());

}  // namespace
}  // namespace xla