1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h"
17
18 #include <optional>
19 #include <utility>
20
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_parser.h"
23 #include "tensorflow/compiler/xla/tests/filecheck.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/platform/test.h"
27
28 namespace xla {
29
30 namespace {
31
32 class ReductionLayoutNormalizerTest : public HloTestBase {
33 public:
CheckReductionLayoutNormalizer(absl::string_view hlo,std::optional<absl::string_view> expected)34 void CheckReductionLayoutNormalizer(
35 absl::string_view hlo, std::optional<absl::string_view> expected) {
36 RunAndFilecheckHloRewrite(hlo, gpu::ReductionLayoutNormalizer{}, expected);
37 }
38 };
39
// Single-operand reduce whose input carries a non-default layout. The
// FileCheck patterns expect the operand to be bitcast to a descending-layout
// shape, reduced over dimensions {0,1,2}, and the result bitcast back to the
// originally requested output shape.
TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTest) {
  const char* const hlo_text = R"(
HloModule ReduceWithLayoutChange

add {
x0 = f32[] parameter(0)
y0 = f32[] parameter(1)
ROOT add0 = f32[] add(x0, y0)
}

ENTRY main {
arg0 = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0)
constant0 = f32[] constant(0)
ROOT reduce0 = f32[4,5,16,12,12]{4,3,2,1,0} reduce(arg0, constant0),
dimensions={1,6,7}, to_apply=add
}

)";

  // Expected post-pass IR, expressed as FileCheck directives.
  const char* const expected_ir = R"(
// CHECK: [[bitcast_0:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_1:%[^ ]+]])
// CHECK: [[reduce_2:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} reduce([[bitcast_0]], [[constant0_3:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[add_4:%[^ ]+]]
// CHECK: ROOT [[bitcast_1_5:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[reduce_2]])
)";

  CheckReductionLayoutNormalizer(hlo_text, expected_ir);
}
66
// Variadic (two-operand argmax) reduce where both inputs share the same
// non-default layout. The FileCheck patterns expect each operand to be
// bitcast, a single tuple-shaped reduce over dimensions {0,1,2}, and then a
// get-tuple-element + bitcast per output that are re-packed into the ROOT
// tuple.
TEST_F(ReductionLayoutNormalizerTest, LayoutCanonicalizerTestVariadic) {
  const char* const hlo_text = R"(
HloModule ReduceWithLayoutChangeVariadic


argmax {
running_max = f32[] parameter(0)
running_max_idx = u32[] parameter(1)
current_value = f32[] parameter(2)
current_value_idx = u32[] parameter(3)

current = (f32[], u32[]) tuple(running_max, running_max_idx)
potential = (f32[], u32[]) tuple(current_value, current_value_idx)

cmp_code = pred[] compare(current_value, running_max), direction=GT

new_max = f32[] select(cmp_code, current_value, running_max)
new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)

ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
}

ENTRY main {
arg0 = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0)
idxs = u32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(1)
constant0 = f32[] constant(0)
constant1 = u32[] constant(0)
ROOT reduce0 = (
f32[4,5,16,12,12]{4,3,2,1,0},
u32[4,5,16,12,12]{4,3,2,1,0}
) reduce(arg0, idxs, constant0,constant1), dimensions={1,6,7}, to_apply=argmax
}


)";

  // Expected post-pass IR, expressed as FileCheck directives.
  const char* const expected_ir = R"(
// CHECK: [[arg0_0:%[^ ]+]] = f32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0)
// CHECK: [[bitcast_1:%[^ ]+]] = f32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[arg0_0]])
// CHECK: [[idxs_2:%[^ ]+]] = u32[4,5,5,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(1)
// CHECK: [[bitcast_1_3:%[^ ]+]] = u32[5,3,3,4,12,12,16,5]{7,6,5,4,3,2,1,0} bitcast([[idxs_2]])
// CHECK: [[reduce_4:%[^ ]+]] = (f32[4,12,12,16,5]{2,1,3,4,0}, u32[4,12,12,16,5]{2,1,3,4,0}) reduce([[bitcast_1]], [[bitcast_1_3]], [[constant0_5:%[^ ]+]], [[constant1_6:%[^ ]+]]), dimensions={0,1,2}, to_apply=[[argmax_7:%[^ ]+]]
// CHECK: [[get_tuple_element_8:%[^ ]+]] = f32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=0
// CHECK: [[bitcast_2_9:%[^ ]+]] = f32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_8]])
// CHECK: [[get_tuple_element_1_10:%[^ ]+]] = u32[4,12,12,16,5]{2,1,3,4,0} get-tuple-element([[reduce_4]]), index=1
// CHECK: [[bitcast_3_11:%[^ ]+]] = u32[4,5,16,12,12]{4,3,2,1,0} bitcast([[get_tuple_element_1_10]])
// CHECK: ROOT [[tuple_12:%[^ ]+]] = (f32[4,5,16,12,12]{4,3,2,1,0}, u32[4,5,16,12,12]{4,3,2,1,0}) tuple([[bitcast_2_9]], [[bitcast_3_11]])
)";

  CheckReductionLayoutNormalizer(hlo_text, expected_ir);
}
117
// Variadic (two-operand argmax) reduce whose inputs have *different* layouts
// ({2,1,0,3} for `arg0`, {3,2,1,0} for `idxs`). The FileCheck patterns expect
// a copy of `idxs` into the layout of `arg0` before both operands are bitcast
// and reduced over dimension {0}. After the structural check the module is
// also executed through RunAndCompare with a 1e-5 absolute/relative tolerance.
TEST_F(ReductionLayoutNormalizerTest,
       LayoutCanonicalizerTestVariadicDifferentLayouts) {
  const char* const hlo_text = R"(
HloModule ReduceWithLayoutChangeVariadicDifferent

argmax {
running_max = f32[] parameter(0)
running_max_idx = u32[] parameter(1)
current_value = f32[] parameter(2)
current_value_idx = u32[] parameter(3)

current = (f32[], u32[]) tuple(running_max, running_max_idx)
potential = (f32[], u32[]) tuple(current_value, current_value_idx)

cmp_code = pred[] compare(current_value, running_max), direction=GT

new_max = f32[] select(cmp_code, current_value, running_max)
new_idx = u32[] select(cmp_code, current_value_idx, running_max_idx)

ROOT out = (f32[], u32[]) tuple(new_max, new_idx)
}

ENTRY main {
arg0 = f32[2,3,4,7]{2,1,0,3} parameter(0)
idxs = u32[2,3,4,7]{3,2,1,0} parameter(1)
constant0 = f32[] constant(0)
constant1 = u32[] constant(0)
ROOT reduce0 = (
f32[2,3,4]{2,1,0},
u32[2,3,4]{2,1,0}
) reduce(arg0, idxs, constant0,constant1), dimensions={3}, to_apply=argmax
}


)";

  // Expected post-pass IR, expressed as FileCheck directives.
  const char* const expected_ir = R"(
// CHECK: [[arg0_0:%[^ ]+]] = f32[2,3,4,7]{2,1,0,3} parameter(0)
// CHECK: [[bitcast_1:%[^ ]+]] = f32[7,2,3,4]{3,2,1,0} bitcast([[arg0_0]])
// CHECK: [[idxs_2:%[^ ]+]] = u32[2,3,4,7]{3,2,1,0} parameter(1)
// CHECK: [[copy_3:%[^ ]+]] = u32[2,3,4,7]{2,1,0,3} copy([[idxs_2]])
// CHECK: [[bitcast_1_4:%[^ ]+]] = u32[7,2,3,4]{3,2,1,0} bitcast([[copy_3]])
// CHECK: ROOT [[reduce0_5:%[^ ]+]] = (f32[2,3,4]{2,1,0}, u32[2,3,4]{2,1,0}) reduce([[bitcast_1]], [[bitcast_1_4]], [[constant0_6:%[^ ]+]], [[constant1_7:%[^ ]+]]), dimensions={0}, to_apply=[[argmax_8:%[^ ]+]]
)";

  // Structural check first, then the execution-based comparison, in the same
  // order as before.
  CheckReductionLayoutNormalizer(hlo_text, expected_ir);

  const ErrorSpec tolerance{1e-5, 1e-5};
  EXPECT_TRUE(RunAndCompare(hlo_text, tolerance));
}
165
166 } // namespace
167 } // namespace xla
168