// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s

// Folding this case would explode the IR
func.func @scatter_fold_explosion() -> tensor<512x1x6400x6400xf32> {
  %base = mhlo.constant dense<0.000000e+00> : tensor<512x1x6400x6400xf32>
  %index = mhlo.constant dense<1> : tensor<1xi32>
  %update = mhlo.constant dense<1.000000e+00> : tensor<511x1x6400x6400xf32>
  // CHECK: mhlo.scatter
  %scatter = "mhlo.scatter"(%base, %index, %update) ({
    ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):
      "mhlo.return"(%arg6) : (tensor<f32>) -> ()
  }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0, 1, 2, 3], scatter_dims_to_operand_dims = [3]>, unique_indices = true} : (tensor<512x1x6400x6400xf32>, tensor<1xi32>, tensor<511x1x6400x6400xf32>) -> tensor<512x1x6400x6400xf32>

  func.return %scatter : tensor<512x1x6400x6400xf32>
}

// -----

// Verify that a full overwrite of the "base" with a scatter is not folded
// if the types mismatch.
// TODO(mhlo): it would be nice to handle this: the update could be broadcast
// to the type of the base here (see the commented sketch after this test).
func.func @scatter_full_overwrite_type_mismatch(%base : tensor<1x1x1xf64>) -> tensor<1x1x1xf64> {
  %0 = mhlo.constant dense<0.28209479177387814> : tensor<1xf64>
  %1 = mhlo.constant dense<0> : tensor<2xi32>
  %scatter = "mhlo.scatter"(%base, %1, %0) ({
    ^bb0(%arg11: tensor<f64>, %arg12: tensor<f64>):
      "mhlo.return"(%arg12) : (tensor<f64>) -> ()
  }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1]>, unique_indices = true} : (tensor<1x1x1xf64>, tensor<2xi32>, tensor<1xf64>) -> tensor<1x1x1xf64>

  // CHECK: %[[SCATTER:.*]] = "mhlo.scatter
  // CHECK: return %[[SCATTER]]
  func.return %scatter : tensor<1x1x1xf64>
}
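// If the TODO above were implemented, the scatter could fold to a broadcast of
// the update to the base type, e.g. (a hypothetical sketch only, not something
// the pass currently emits):
//   %b = "mhlo.broadcast_in_dim"(%0) {broadcast_dimensions = dense<2> : tensor<1xi64>}
//       : (tensor<1xf64>) -> tensor<1x1x1xf64>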
// -----

// Verify that a full overwrite of the "base" with a scatter is correctly folded
// even if the tensor is huge.
func.func @scatter_full_overwrite() -> tensor<512x1x6400x6400xf32> {
  %base = mhlo.constant dense<0.000000e+00> : tensor<512x1x6400x6400xf32>
  %index = mhlo.constant dense<0> : tensor<1xi32>
  %update = mhlo.constant dense<1.000000e+00> : tensor<512x1x6400x6400xf32>
  %scatter = "mhlo.scatter"(%base, %index, %update) ({
    ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):
      "mhlo.return"(%arg6) : (tensor<f32>) -> ()
  }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0, 1, 2, 3], scatter_dims_to_operand_dims = [3]>, unique_indices = true} : (tensor<512x1x6400x6400xf32>, tensor<1xi32>, tensor<512x1x6400x6400xf32>) -> tensor<512x1x6400x6400xf32>

  // CHECK: %[[FOLD:.*]] = mhlo.constant dense<1.000000e+00> : tensor<512x1x6400x6400xf32>
  // CHECK: return %[[FOLD]]
  func.return %scatter : tensor<512x1x6400x6400xf32>
}

// -----

// Verify that a full overwrite of the "base" with a scatter is correctly folded
// even if the base and update are not constant values.
func.func @scatter_full_overwrite_non_const(
    %base : tensor<512x1x6400x6400xf32>,
    %update : tensor<512x1x6400x6400xf32>) -> tensor<512x1x6400x6400xf32> {
  %index = mhlo.constant dense<0> : tensor<1xi32>
  %scatter = "mhlo.scatter"(%base, %index, %update) ({
    ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):
      "mhlo.return"(%arg6) : (tensor<f32>) -> ()
  }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0, 1, 2, 3], scatter_dims_to_operand_dims = [3]>, unique_indices = true} : (tensor<512x1x6400x6400xf32>, tensor<1xi32>, tensor<512x1x6400x6400xf32>) -> tensor<512x1x6400x6400xf32>

  // CHECK: return %arg1
  func.return %scatter : tensor<512x1x6400x6400xf32>
}

// -----

// Verify that a full overwrite of the "base" with a scatter is not folded when
// there is a non-identity computation.
func.func public @scatter_non_identity(%arg0: tensor<12xbf16>, %arg1: tensor<12xbf16>) -> tensor<12xbf16> {
  %0 = mhlo.constant dense<0> : tensor<1xi32>
  %1 = "mhlo.scatter"(%arg0, %0, %arg1) ({
    ^bb0(%arg2: tensor<bf16>, %arg3: tensor<bf16>):
      %2 = mhlo.add %arg2, %arg3 : tensor<bf16>
      "mhlo.return"(%2) : (tensor<bf16>) -> ()
  }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], scatter_dims_to_operand_dims = [0]>, unique_indices = true} : (tensor<12xbf16>, tensor<1xi32>, tensor<12xbf16>) -> tensor<12xbf16>
  // CHECK: %[[SCATTER:.*]] = "mhlo.scatter
  // CHECK: return %[[SCATTER]]
  func.return %1 : tensor<12xbf16>
}
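// Note: folding this scatter to just return %arg1 would be incorrect because
// the update computation is an add rather than an overwrite: each result
// element is base + update (e.g. a base element of 1.0 with an update of 2.0
// yields 3.0, not 2.0).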