1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
17
18 #include <memory>
19 #include <utility>
20 #include <vector>
21
22 #include "tensorflow/compiler/xla/layout_util.h"
23 #include "tensorflow/compiler/xla/literal.h"
24 #include "tensorflow/compiler/xla/permutation_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/service/hlo_parser.h"
29 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
30 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
31 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/test.h"
34 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
35 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
36 #include "tensorflow/compiler/xla/types.h"
37
38 namespace xla {
39 namespace {
40
// Short alias for the pattern-matcher namespace; the tests below spell
// patterns as m::Constant(), m::Convert(), etc. and hand them to GmockMatch.
namespace m = xla::match;

// Test fixture name used by every TEST_F in this file; it is a plain alias of
// HloTestBase and adds no state of its own.
using HloConstantFoldingTest = HloTestBase;
44
// A scalar f32 constant fed through a convert to s64 must collapse into a
// single constant whose first element reads back as 42.
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
  HloComputation::Builder b(TestName());
  HloInstruction* f32_source = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  const Shape s64_scalar = ShapeUtil::MakeShape(S64, {});
  b.AddInstruction(HloInstruction::CreateConvert(s64_scalar, f32_source));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity check: before the pass the root is the convert of our constant.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Convert().WithOperand(0, m::Op().Is(f32_source))));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(module.get()));
  EXPECT_TRUE(changed);

  // After the pass the root must be a constant carrying the converted value.
  HloInstruction* folded_root = entry->root_instruction();
  EXPECT_THAT(folded_root, GmockMatch(m::Constant()));
  EXPECT_EQ(folded_root->literal().GetFirstElement<int64_t>(), 42);
}
67
// Mirror of ConvertF32ToS64: an s64 scalar constant converted to f32 must be
// replaced by a constant whose first element reads back as 42.0f.
TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
  HloComputation::Builder b(TestName());
  HloInstruction* s64_source = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64_t>(42)));
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  b.AddInstruction(HloInstruction::CreateConvert(f32_scalar, s64_source));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity check: before the pass the root is the convert of our constant.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Convert().WithOperand(0, m::Op().Is(s64_source))));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(module.get()));
  EXPECT_TRUE(changed);

  // After the pass the root must be a constant carrying the converted value.
  HloInstruction* folded_root = entry->root_instruction();
  EXPECT_THAT(folded_root, GmockMatch(m::Constant()));
  EXPECT_EQ(folded_root->literal().GetFirstElement<float>(), 42.0f);
}
89
// Rank-1 variant of the convert tests: a two-element f32 constant converted
// to s64 must fold to a constant whose elements read back as {42, 19}.
TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
  HloComputation::Builder b(TestName());
  HloInstruction* f32_source = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({42.0f, 19.0f})));
  const Shape s64_pair = ShapeUtil::MakeShape(S64, {2});
  b.AddInstruction(HloInstruction::CreateConvert(s64_pair, f32_source));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // Sanity check: before the pass the root is the convert of our constant.
  EXPECT_THAT(entry->root_instruction(),
              GmockMatch(m::Convert().WithOperand(0, m::Op().Is(f32_source))));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(module.get()));
  EXPECT_TRUE(changed);

  // Check each element of the folded constant individually.
  HloInstruction* folded_root = entry->root_instruction();
  EXPECT_THAT(folded_root, GmockMatch(m::Constant()));
  const Literal& folded = folded_root->literal();
  EXPECT_EQ(folded.Get<int64_t>({0}), 42);
  EXPECT_EQ(folded.Get<int64_t>({1}), 19);
}
111
// Concatenating several constant operands along one dimension must fold to a
// single constant of the concatenated shape. Each case supplies the base
// dimensions, the dimension to concatenate along, and the per-operand extent
// in that dimension.
TEST_F(HloConstantFoldingTest, Concatenate) {
  struct ConcatCase {
    int concat_dimension;
    std::vector<int64_t> dimensions;
    std::vector<int64_t> concat_sizes;
  };
  const ConcatCase kCases[] = {
      {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}},
      {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}},
  };

  for (const ConcatCase& tc : kCases) {
    HloComputation::Builder builder(TestName());

    // Working copy of the dimensions; the concat dimension is overwritten for
    // every operand and finally with the summed extent.
    std::vector<int64_t> dims = tc.dimensions;
    int64_t total_extent = 0;
    std::vector<HloInstruction*> pieces;
    for (const int64_t piece_extent : tc.concat_sizes) {
      dims[tc.concat_dimension] = piece_extent;
      total_extent += piece_extent;
      pieces.push_back(builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateFromDimensions(F32, dims))));
    }

    dims[tc.concat_dimension] = total_extent;
    const Shape concat_shape = ShapeUtil::MakeShape(F32, dims);
    builder.AddInstruction(HloInstruction::CreateConcatenate(
        concat_shape, pieces, tc.concat_dimension));

    auto module = CreateNewVerifiedModule();
    HloComputation* entry = module->AddEntryComputation(builder.Build());

    HloConstantFolding folding_pass;
    TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(module.get()));
    EXPECT_TRUE(changed);

    // The concatenate must be gone, replaced by a constant of the same shape.
    HloInstruction* folded_root = entry->root_instruction();
    EXPECT_THAT(folded_root, GmockMatch(m::Constant()));
    EXPECT_TRUE(ShapeUtil::Equal(folded_root->shape(), concat_shape));
  }
}
152
// A unit-stride slice of a random rank-5 constant must fold to a constant of
// the sliced shape.
TEST_F(HloConstantFoldingTest, Slice) {
  const std::vector<int64_t> input_dims = {11, 8, 7, 5, 9};
  const std::vector<int64_t> starts = {4, 2, 3, 1, 5};
  const std::vector<int64_t> limits = {10, 8, 6, 5, 9};
  const std::vector<int64_t> strides = {1, 1, 1, 1, 1};

  HloComputation::Builder builder(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
  TF_ASSERT_OK_AND_ASSIGN(
      auto random_input,
      LiteralUtil::CreateRandomLiteral<F32>(input_shape, 0.0, 1.0));
  HloInstruction* source = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(random_input)));

  // Each extent is limit - start for the corresponding dimension.
  const Shape sliced_shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});
  builder.AddInstruction(HloInstruction::CreateSlice(sliced_shape, source,
                                                     starts, limits, strides));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(module.get()));
  EXPECT_TRUE(changed);

  // Only the opcode and the shape of the result are verified here.
  HloInstruction* folded_root = entry->root_instruction();
  EXPECT_THAT(folded_root, GmockMatch(m::Constant()));
  EXPECT_TRUE(ShapeUtil::Equal(folded_root->shape(), sliced_shape));
}
178
// Transposing a random rank-5 constant must fold to a constant, and every
// cell of the folded result must equal the input cell found by mapping the
// output index back through the inverse permutation.
TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
  const int64_t input_dims[] = {11, 8, 7, 5, 9};
  const int64_t permutation[] = {1, 2, 0, 4, 3};

  HloComputation::Builder builder(TestName());
  TF_ASSERT_OK_AND_ASSIGN(auto random_input,
                          LiteralUtil::CreateRandomLiteral<F32>(
                              ShapeUtil::MakeShape(F32, input_dims), 0.0, 1.0));
  // Keep a copy of the data: the literal itself is moved into the constant.
  auto reference = random_input.Clone();
  HloInstruction* source = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(random_input)));

  // input_dims reordered by `permutation`.
  const Shape transposed_shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
  builder.AddInstruction(
      HloInstruction::CreateTranspose(transposed_shape, source, permutation));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(module.get()));
  EXPECT_TRUE(changed);

  HloInstruction* folded_root = entry->root_instruction();
  EXPECT_THAT(folded_root, GmockMatch(m::Constant()));
  EXPECT_TRUE(ShapeUtil::Compatible(folded_root->shape(), transposed_shape));

  using ElementT = primitive_util::PrimitiveTypeToNative<F32>::type;
  bool all_cells_equal = true;
  folded_root->literal().EachCell<ElementT>(
      [&](absl::Span<const int64_t> out_index, ElementT out_value) {
        const std::vector<int64_t> in_index =
            PermuteInverse(out_index, permutation);
        // Once a mismatch has been seen, stop consulting the reference.
        if (all_cells_equal &&
            !(out_value == reference.Get<ElementT>(in_index))) {
          all_cells_equal = false;
        }
      });
  EXPECT_TRUE(all_cells_equal);
}
212
// HLO text shared by the ConstantFoldReduce and ConstantFoldReduceNoLayout
// tests. It declares a scalar s32 `add` computation and an entry computation
// whose root reduces the constant {1, 2, 3} over dimension 0 with `add`,
// starting from the constant 0.
const char* const kConstantFoldReduce = R"(
HloModule ConstantFoldReduce

add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}

ENTRY r {
x = s32[3] constant({1, 2, 3})
init = s32[] constant(0)
ROOT reduce = s32[] reduce(x, init), dimensions={0}, to_apply=add
})";
227
// The reduce in kConstantFoldReduce sums {1, 2, 3} from an initial value of
// 0, so after folding the root's literal must hold 6.
TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kConstantFoldReduce));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(module.get()));
  EXPECT_TRUE(changed);

  HloInstruction* folded_root =
      module->entry_computation()->root_instruction();
  EXPECT_EQ(6, folded_root->literal().GetFirstElement<int32_t>());
}
240
// Same module as ConstantFoldReduce, but the layout is cleared on the root of
// the first computation the module lists. The pass is then expected to report
// no change and leave the reduce in place.
TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kConstantFoldReduce));

  // NOTE(review): this relies on the first listed computation being the `add`
  // reducer rather than the entry computation -- confirm the ordering holds.
  HloComputation* first_computation = *module->computations().begin();
  HloInstruction* reducer_root = first_computation->root_instruction();
  LayoutUtil::ClearLayout(reducer_root->mutable_shape());

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(module.get()));
  EXPECT_FALSE(changed);

  EXPECT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Reduce()));
}
253
// HLO text for the DoesNotFoldLargePad test. The entry root pads a
// single-element f32[1,1,1] constant with the scalar constant 42 out to
// f32[2048,2048,128]; both pad operands are constants.
const char* const kConstantFoldLargePad = R"(
HloModule ConstantFoldLargePad

ENTRY r {
a = f32[1,1,1] constant({{{7}}})
b = f32[] constant(42)
ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63
})";
262
// A pad of two constants that produces a very large result must be left
// alone: the pass reports no change and the root is still the pad.
TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) {
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kConstantFoldLargePad));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(parsed_module.get()));
  EXPECT_FALSE(changed);

  HloInstruction* root = parsed_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, GmockMatch(m::Pad(m::Constant(), m::Constant())));
}
273
// The entry computation only calls Fn, and Fn contains a token-producing
// after-all next to its iota root. The pass must report no change.
TEST_F(HloConstantFoldingTest, DontFoldSubcomputationContainingAfterAll) {
  const char* const kHloText = R"(
HloModule test

Fn {
tok = token[] after-all()
ROOT root = f32[10] iota(), iota_dimension=0
}

ENTRY entry {
ROOT call = f32[10] call(), to_apply=Fn
})";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kHloText));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloPass(&folding_pass, parsed_module.get()));
  EXPECT_FALSE(changed);
}
293
// The rng sits two levels down: entry calls Fn, whose root is a fusion that
// calls InnerFn, whose root is the rng. The pass must report no change.
TEST_F(HloConstantFoldingTest,
       DontFoldSubcomputationTransitivelyContainingRng) {
  const char* const kHloText = R"(
HloModule test

InnerFn {
c0 = f32[] constant(0)
c1 = f32[] constant(1)
ROOT rng = f32[10] rng(c0, c1), distribution=rng_uniform
}

Fn {
ROOT fusion = f32[10] fusion(), kind=kLoop, calls=InnerFn
}

ENTRY entry {
ROOT call = f32[10] call(), to_apply=Fn
})";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kHloText));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloPass(&folding_pass, parsed_module.get()));
  EXPECT_FALSE(changed);
}
319
// Checks which of four tuple elements get folded:
//   - a lone broadcast of a constant stays a broadcast,
//   - an add of two broadcasts stays an add of broadcasts,
//   - an add of a broadcast and a plain constant (either operand order)
//     becomes a constant.
TEST_F(HloConstantFoldingTest, FoldOpsWhereOneOperandIsBroadcast) {
  const char* const kHloText = R"(
HloModule test

ENTRY entry {
not_folded1 = f32[4] broadcast(f32[] constant(1))
not_folded2 = add(f32[4] broadcast(f32[] constant(2)),
f32[4] broadcast(f32[] constant(3)))
folded1 = add(f32[4] broadcast(f32[] constant(5)),
f32[4] constant({0,1,2,3}))
folded2 = add(f32[4] constant({0,1,2,3}),
f32[4] broadcast(f32[] constant(5)))
ROOT root = tuple(not_folded1, not_folded2, folded1, folded2)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kHloText));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloPass(&folding_pass, parsed_module.get()));
  EXPECT_TRUE(changed);

  // Tuple elements appear in the order: not_folded1, not_folded2, folded1,
  // folded2.
  HloInstruction* root = parsed_module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      GmockMatch(m::Tuple(
          m::Broadcast(m::Constant()),
          m::Add(m::Broadcast(m::Constant()), m::Broadcast(m::Constant())),
          m::Constant(),
          m::Constant())));
}
348
// Runs the pass on a bf16 reduce-window whose input is a large broadcast of a
// scalar constant; the only expectation is that the pass reports a change.
TEST_F(HloConstantFoldingTest, BigReduceWindow) {
  constexpr absl::string_view kHloText = R"(
HloModule test

add_bf16 {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}

ENTRY accumulated_all_reduce {
x = bf16[160,10,10,512]{3,2,1,0} broadcast(bf16[] constant(1.0))
init = bf16[] constant(0)
ROOT reduce-window = reduce-window(x, init), window={size=1x2x2x1 stride=1x2x2x1}, to_apply=add_bf16
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kHloText));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloPass(&folding_pass, parsed_module.get()));
  EXPECT_TRUE(changed);
}
372
373 } // namespace
374 } // namespace xla
375