• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h"
17 
18 #include <optional>
19 #include <utility>
20 
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_parser.h"
23 #include "tensorflow/compiler/xla/tests/filecheck.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/platform/test.h"
27 
28 namespace xla {
29 
30 namespace {
31 
32 class ReductionLayoutNormalizerTest : public HloTestBase {
33  public:
CheckReductionLayoutNormalizer(absl::string_view hlo,std::optional<absl::string_view> expected)34   void CheckReductionLayoutNormalizer(
35       absl::string_view hlo, std::optional<absl::string_view> expected) {
36     RunAndFilecheckHloRewrite(hlo, gpu::ReductionLayoutNormalizer{}, expected);
37   }
38 };
39 
// Single-operand reduce whose input has the non-default layout
// {2,3,5,4,0,7,6,1}. The expected rewrite bitcasts the input to
// f32[5,3,3,4,12,12,16,5] in descending layout {7,...,0}, reduces the leading
// dimensions {0,1,2}, then bitcasts the result back to the original output
// shape f32[4,5,16,12,12]{4,3,2,1,0}.
TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTest) {
  constexpr absl::string_view kHloText = R"(
HloModule ReduceWithLayoutChange

add {
  x0 = f32[] parameter(0)
  y0 = f32[] parameter(1)
  ROOT add0 = f32[] add(x0, y0)
}

ENTRY main {
  arg0 = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1}  parameter(0)
  constant0 = f32[] constant(0)
  ROOT reduce0 = f32[4,5,16,12,12]{4,3,2,1,0} reduce(arg0, constant0),
    dimensions={1,6,7}, to_apply=add
}

)";

  // FileCheck pattern for the normalized module.
  constexpr absl::string_view kExpectedRewrite = R"(
// CHECK:  [[bitcast_0:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_1:%[^ ]+]])
// CHECK:  [[reduce_2:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} reduce([[bitcast_0]], [[constant0_3:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[add_4:%[^ ]+]]
// CHECK:  ROOT [[bitcast_1_5:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[reduce_2]])
      )";

  CheckReductionLayoutNormalizer(kHloText, kExpectedRewrite);
}
66 
// Variadic (value, index) argmax reduce. Both operands share the non-default
// layout {2,3,5,4,0,7,6,1}. The expected rewrite bitcasts each operand to the
// descending layout, reduces over {0,1,2}, and then unpacks the tuple result
// with get-tuple-element. Each element is bitcast back to the original output
// shape before the results are re-tupled as the new root.
TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTestVariadic) {
  constexpr absl::string_view kHloText = R"(
HloModule ReduceWithLayoutChangeVariadic


argmax {
  running_max = f32[] parameter(0)
  running_max_idx = u32[] parameter(1)
  current_value = f32[] parameter(2)
  current_value_idx = u32[] parameter(3)

  current = (f32[], u32[]) tuple(running_max, running_max_idx)
  potential = (f32[], u32[]) tuple(current_value, current_value_idx)

  cmp_code = pred[] compare(current_value, running_max), direction=GT

  new_max = f32[] select(cmp_code, current_value, running_max)
  new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)

  ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
}

ENTRY main {
  arg0 = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1}  parameter(0)
  idxs = u32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1}  parameter(1)
  constant0 = f32[] constant(0)
  constant1 = u32[] constant(0)
  ROOT reduce0 = (
      f32[4,5,16,12,12]{4,3,2,1,0},
      u32[4,5,16,12,12]{4,3,2,1,0}
    ) reduce(arg0, idxs, constant0,constant1), dimensions={1,6,7}, to_apply=argmax
}


)";

  // FileCheck pattern for the normalized module.
  constexpr absl::string_view kExpectedRewrite = R"(
// CHECK:  [[arg0_0:%[^ ]+]] = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0)
// CHECK:  [[bitcast_1:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_0]])
// CHECK:  [[idxs_2:%[^ ]+]] = u32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(1)
// CHECK:  [[bitcast_1_3:%[^ ]+]] = u32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[idxs_2]])
// CHECK:  [[reduce_4:%[^ ]+]] = (f32[4,12,12,16,5]{2,1,3,4,0}, u32[4,12,12,16,5]{2,1,3,4,0}) reduce([[bitcast_1]], [[bitcast_1_3]], [[constant0_5:%[^ ]+]], [[constant1_6:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[argmax_7:%[^ ]+]]
// CHECK:  [[get_tuple_element_8:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=0
// CHECK:  [[bitcast_2_9:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_8]])
// CHECK:  [[get_tuple_element_1_10:%[^ ]+]] = u32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=1
// CHECK:  [[bitcast_3_11:%[^ ]+]] = u32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_1_10]])
// CHECK:  ROOT [[tuple_12:%[^ ]+]] = (f32[4,5,16,12,12]{4,3,2,1,0}, u32[4,5,16,12,12]{4,3,2,1,0}) tuple([[bitcast_2_9]], [[bitcast_3_11]])
      )";

  CheckReductionLayoutNormalizer(kHloText, kExpectedRewrite);
}
117 
// Variadic argmax reduce whose two operands have different layouts:
// - arg0 uses {2,1,0,3};
// - idxs uses the default {3,2,1,0}.
// The expected rewrite first copies idxs into arg0's layout {2,1,0,3}. It then
// bitcasts both operands to [7,2,3,4]{3,2,1,0} and reduces over dimension 0.
// Finally, the module is also executed and compared numerically.
TEST_F(ReductionLayoutNormalizerTest,
       LayoutCanonicalizerTestVariadicDifferentLayouts) {
  constexpr absl::string_view kHloText = R"(
HloModule ReduceWithLayoutChangeVariadicDifferent

argmax {
  running_max = f32[] parameter(0)
  running_max_idx = u32[] parameter(1)
  current_value = f32[] parameter(2)
  current_value_idx = u32[] parameter(3)

  current = (f32[], u32[]) tuple(running_max, running_max_idx)
  potential = (f32[], u32[]) tuple(current_value, current_value_idx)

  cmp_code = pred[] compare(current_value, running_max), direction=GT

  new_max = f32[] select(cmp_code, current_value, running_max)
  new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)

  ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
}

ENTRY main {
  arg0 = f32[2,3,4,7]{2,1,0,3}  parameter(0)
  idxs = u32[2,3,4,7]{3,2,1,0}  parameter(1)
  constant0 = f32[] constant(0)
  constant1 = u32[] constant(0)
  ROOT reduce0 = (
      f32[2,3,4]{2,1,0},
      u32[2,3,4]{2,1,0}
    ) reduce(arg0, idxs, constant0,constant1), dimensions={3}, to_apply=argmax
}


)";

  // FileCheck pattern for the normalized module.
  constexpr absl::string_view kExpectedRewrite = R"(
// CHECK:  [[arg0_0:%[^ ]+]] = f32[2,3,4,7]{2,1,0,3} parameter(0)
// CHECK:  [[bitcast_1:%[^ ]+]] = f32[7,2,3,4]{3,2,1,0} bitcast([[arg0_0]])
// CHECK:  [[idxs_2:%[^ ]+]] = u32[2,3,4,7]{3,2,1,0} parameter(1)
// CHECK:  [[copy_3:%[^ ]+]] = u32[2,3,4,7]{2,1,0,3} copy([[idxs_2]])
// CHECK:  [[bitcast_1_4:%[^ ]+]] = u32[7,2,3,4]{3,2,1,0} bitcast([[copy_3]])
// CHECK:  ROOT [[reduce0_5:%[^ ]+]] = (f32[2,3,4]{2,1,0}, u32[2,3,4]{2,1,0}) reduce([[bitcast_1]], [[bitcast_1_4]], [[constant0_6:%[^ ]+]], [[constant1_7:%[^ ]+]]), dimensions={0}, to_apply=[[argmax_8:%[^ ]+]]
      )";

  CheckReductionLayoutNormalizer(kHloText, kExpectedRewrite);

  // Numerical check of the same module with absolute/relative tolerance 1e-5.
  const ErrorSpec tolerance{1e-5, 1e-5};
  EXPECT_TRUE(RunAndCompare(kHloText, tolerance));
}
165 
166 }  // namespace
167 }  // namespace xla
168