// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func.func(canonicalize)' | FileCheck %s

// Verify that this scatter is NOT folded. All operands are constants, but
// folding would require evaluating the scatter over a 512x1x6400x6400 tensor,
// and it could materialize a huge dense constant, exploding the IR. The update
// (511x1x6400x6400) does not have the base's type (512x1x6400x6400), so this is
// not the trivially-foldable "full overwrite" case either.
// CHECK-LABEL: @scatter_fold_explosion(
func.func @scatter_fold_explosion() ->  tensor<512x1x6400x6400xf32> {
  %base = mhlo.constant dense<0.000000e+00> : tensor<512x1x6400x6400xf32>
  %index = mhlo.constant dense<1> : tensor<1xi32>
  %update = mhlo.constant dense<1.000000e+00> : tensor<511x1x6400x6400xf32>
  // CHECK: mhlo.scatter
  %scatter = "mhlo.scatter"(%base, %index, %update) ({
    ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):
      "mhlo.return"(%arg6) : (tensor<f32>) -> ()
  }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0, 1, 2, 3], scatter_dims_to_operand_dims = [3]>, unique_indices = true} : (tensor<512x1x6400x6400xf32>, tensor<1xi32>, tensor<511x1x6400x6400xf32>) -> tensor<512x1x6400x6400xf32>

  func.return %scatter :  tensor<512x1x6400x6400xf32>
}

// -----

// Verify that a full overwrite of the "base" with a scatter is not folded
// when the base and update types differ (here tensor<1x1x1xf64> vs
// tensor<1xf64>, even though both hold a single element).
// TODO(mhlo): this would be nice to handle: the update could be broadcast
// to the type of the base here.
// CHECK-LABEL: @scatter_full_overwrite_type_mismatch(
func.func @scatter_full_overwrite_type_mismatch(%base : tensor<1x1x1xf64>) ->  tensor<1x1x1xf64> {
  %0 = mhlo.constant dense<0.28209479177387814> : tensor<1xf64>
  %1 = mhlo.constant dense<0> : tensor<2xi32>
  %scatter = "mhlo.scatter"(%base, %1, %0) ({
  ^bb0(%arg11: tensor<f64>, %arg12: tensor<f64>):
    "mhlo.return"(%arg12) : (tensor<f64>) -> ()
  }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1]>, unique_indices = true} : (tensor<1x1x1xf64>, tensor<2xi32>, tensor<1xf64>) -> tensor<1x1x1xf64>

  // The scatter must survive canonicalization and feed the return.
  // CHECK: %[[SCATTER:.*]] = "mhlo.scatter
  // CHECK: return %[[SCATTER]]
  func.return %scatter :  tensor<1x1x1xf64>
}

// -----

// Verify that a full overwrite of the "base" with a scatter is correctly folded
// even if the tensor is huge. The update has exactly the base's type, it is
// written at index 0, and the region just returns its second argument. The
// scatter therefore reduces to the update constant itself, with no new
// constant data created.
// CHECK-LABEL: @scatter_full_overwrite(
func.func @scatter_full_overwrite() ->  tensor<512x1x6400x6400xf32> {
  %base = mhlo.constant dense<0.000000e+00> : tensor<512x1x6400x6400xf32>
  %index = mhlo.constant dense<0> : tensor<1xi32>
  %update = mhlo.constant dense<1.000000e+00> : tensor<512x1x6400x6400xf32>
  %scatter = "mhlo.scatter"(%base, %index, %update) ({
    ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):
      "mhlo.return"(%arg6) : (tensor<f32>) -> ()
  }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0, 1, 2, 3], scatter_dims_to_operand_dims = [3]>, unique_indices = true} : (tensor<512x1x6400x6400xf32>, tensor<1xi32>, tensor<512x1x6400x6400xf32>) -> tensor<512x1x6400x6400xf32>

  // The function must return the (splat) update constant directly.
  // CHECK: %[[FOLD:.*]] = mhlo.constant dense<1.000000e+00> : tensor<512x1x6400x6400xf32>
  // CHECK: return %[[FOLD]]
  func.return %scatter :  tensor<512x1x6400x6400xf32>
}

// -----

// Verify that a full overwrite of the "base" with a scatter is correctly folded
// even if the base and update are not constant values: the scatter is replaced
// by the update operand. The update argument's printed name is captured from
// the signature, so the check does not rely on positional %argN numbering.
// CHECK-LABEL: @scatter_full_overwrite_non_const(
// CHECK-SAME:    %{{[a-zA-Z0-9_]+}}: tensor<512x1x6400x6400xf32>,
// CHECK-SAME:    %[[UPDATE:[a-zA-Z0-9_]+]]: tensor<512x1x6400x6400xf32>)
func.func @scatter_full_overwrite_non_const(
        %base : tensor<512x1x6400x6400xf32>,
        %update : tensor<512x1x6400x6400xf32>) ->  tensor<512x1x6400x6400xf32> {
  %index = mhlo.constant dense<0> : tensor<1xi32>
  %scatter = "mhlo.scatter"(%base, %index, %update) ({
    ^bb0(%arg5: tensor<f32>, %arg6: tensor<f32>):
      "mhlo.return"(%arg6) : (tensor<f32>) -> ()
  }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0, 1, 2, 3], scatter_dims_to_operand_dims = [3]>, unique_indices = true} : (tensor<512x1x6400x6400xf32>, tensor<1xi32>, tensor<512x1x6400x6400xf32>) -> tensor<512x1x6400x6400xf32>

  // CHECK: return %[[UPDATE]]
  func.return %scatter :  tensor<512x1x6400x6400xf32>
}

// -----

// Verify that a full overwrite of the "base" with a scatter is not folded when
// there is a non-identity computation. Here the region combines the old and new
// values with mhlo.add instead of returning the update, so the result depends
// on %arg0 and cannot be replaced by the update operand.
// CHECK-LABEL: @scatter_non_identity(
func.func public @scatter_non_identity(%arg0: tensor<12xbf16>, %arg1: tensor<12xbf16>) -> tensor<12xbf16> {
  %0 = mhlo.constant dense<0> : tensor<1xi32>
  %1 = "mhlo.scatter"(%arg0, %0, %arg1) ({
  ^bb0(%arg2: tensor<bf16>, %arg3: tensor<bf16>):
    %2 = mhlo.add %arg2, %arg3 : tensor<bf16>
    "mhlo.return"(%2) : (tensor<bf16>) -> ()
  }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], scatter_dims_to_operand_dims = [0]>, unique_indices = true} : (tensor<12xbf16>, tensor<1xi32>, tensor<12xbf16>) -> tensor<12xbf16>
  // CHECK: %[[SCATTER:.*]] = "mhlo.scatter
  // CHECK: return %[[SCATTER]]
  func.return %1 : tensor<12xbf16>
}