• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
17 
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 
22 #include "tensorflow/compiler/xla/layout_util.h"
23 #include "tensorflow/compiler/xla/literal.h"
24 #include "tensorflow/compiler/xla/permutation_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/service/hlo_parser.h"
29 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
30 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
31 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/test.h"
34 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
35 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
36 #include "tensorflow/compiler/xla/types.h"
37 
38 namespace xla {
39 namespace {
40 
41 namespace m = xla::match;
42 
43 using HloConstantFoldingTest = HloTestBase;
44 
// A convert from a scalar F32 constant to S64 must be folded into a single
// S64 constant carrying the converted value.
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const f32_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  const Shape s64_scalar = ShapeUtil::MakeShape(S64, {});
  builder.AddInstruction(
      HloInstruction::CreateConvert(s64_scalar, f32_constant));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Sanity check: before the pass runs the root is the convert itself.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Convert().WithOperand(0, m::Op().Is(f32_constant))));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(module.get()));
  EXPECT_TRUE(changed);

  // The root must be re-fetched: the pass is expected to have replaced it.
  HloInstruction* folded_root = entry->root_instruction();
  EXPECT_THAT(folded_root, GmockMatch(m::Constant()));
  EXPECT_EQ(folded_root->literal().GetFirstElement<int64_t>(), 42);
}
67 
// A convert from a scalar S64 constant to F32 must be folded into a single
// F32 constant carrying the converted value.
TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const s64_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64_t>(42)));
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  builder.AddInstruction(
      HloInstruction::CreateConvert(f32_scalar, s64_constant));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Sanity check: before the pass runs the root is the convert itself.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Convert().WithOperand(0, m::Op().Is(s64_constant))));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(module.get()));
  EXPECT_TRUE(changed);

  // The root must be re-fetched: the pass is expected to have replaced it.
  HloInstruction* folded_root = entry->root_instruction();
  EXPECT_THAT(folded_root, GmockMatch(m::Constant()));
  EXPECT_EQ(folded_root->literal().GetFirstElement<float>(), 42.0f);
}
89 
// A convert from a two-element F32 constant vector to S64 must be folded into
// a constant whose elements are the converted inputs, in order.
TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
  HloComputation::Builder builder(TestName());
  HloInstruction* const f32_vector =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({42.0f, 19.0f})));
  const Shape s64_vector = ShapeUtil::MakeShape(S64, {2});
  builder.AddInstruction(
      HloInstruction::CreateConvert(s64_vector, f32_vector));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Sanity check: before the pass runs the root is the convert itself.
  EXPECT_THAT(
      entry->root_instruction(),
      GmockMatch(m::Convert().WithOperand(0, m::Op().Is(f32_vector))));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(module.get()));
  EXPECT_TRUE(changed);

  // The root must be re-fetched: the pass is expected to have replaced it.
  HloInstruction* folded_root = entry->root_instruction();
  EXPECT_THAT(folded_root, GmockMatch(m::Constant()));
  EXPECT_EQ(folded_root->literal().Get<int64_t>({0}), 42);
  EXPECT_EQ(folded_root->literal().Get<int64_t>({1}), 19);
}
111 
// Concatenating several F32 constants along one dimension must fold into a
// single constant whose shape equals the declared concatenate shape.
TEST_F(HloConstantFoldingTest, Concatenate) {
  // One scenario: the dimension to concatenate along, the base operand
  // dimensions, and the extent of each operand along that dimension.
  struct ConcatCase {
    int concat_dimension;
    std::vector<int64_t> dimensions;
    std::vector<int64_t> concat_sizes;
  };
  // The entry of `dimensions` at `concat_dimension` is only a placeholder; it
  // is overwritten for every operand and for the result below.
  const ConcatCase concat_cases[] = {
      {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}},
      {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}},
  };

  for (const ConcatCase& concat_case : concat_cases) {
    HloComputation::Builder builder(TestName());
    const int axis = concat_case.concat_dimension;
    std::vector<int64_t> dims = concat_case.dimensions;

    // Build one constant operand per requested size, accumulating the total
    // extent along the concatenation axis.
    std::vector<HloInstruction*> pieces;
    int64_t total_extent = 0;
    for (const int64_t piece_extent : concat_case.concat_sizes) {
      dims[axis] = piece_extent;
      total_extent += piece_extent;
      pieces.push_back(builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateFromDimensions(F32, dims))));
    }

    dims[axis] = total_extent;
    Shape result_shape = ShapeUtil::MakeShape(F32, dims);
    builder.AddInstruction(
        HloInstruction::CreateConcatenate(result_shape, pieces, axis));

    auto module = CreateNewVerifiedModule();
    HloComputation* entry = module->AddEntryComputation(builder.Build());

    HloConstantFolding folding_pass;
    TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(module.get()));
    EXPECT_TRUE(changed);

    HloInstruction* folded_root = entry->root_instruction();
    EXPECT_THAT(folded_root, GmockMatch(m::Constant()));
    EXPECT_TRUE(ShapeUtil::Equal(folded_root->shape(), result_shape));
  }
}
152 
// Slicing a random rank-5 F32 constant with unit strides must fold into a
// constant whose shape equals the declared slice shape.
TEST_F(HloConstantFoldingTest, Slice) {
  HloComputation::Builder builder(TestName());

  const int64_t input_dims[] = {11, 8, 7, 5, 9};
  const int64_t starts[] = {4, 2, 3, 1, 5};
  const int64_t limits[] = {10, 8, 6, 5, 9};
  const int64_t strides[] = {1, 1, 1, 1, 1};

  const Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
  TF_ASSERT_OK_AND_ASSIGN(
      auto input_literal,
      LiteralUtil::CreateRandomLiteral<F32>(input_shape, 0.0, 1.0));
  HloInstruction* input_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(input_literal)));

  // Each extent below is limits[i] - starts[i] for the arrays above.
  Shape slice_shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});
  builder.AddInstruction(HloInstruction::CreateSlice(
      slice_shape, input_constant, starts, limits, strides));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(module.get()));
  EXPECT_TRUE(changed);

  HloInstruction* folded_root = entry->root_instruction();
  EXPECT_THAT(folded_root, GmockMatch(m::Constant()));
  EXPECT_TRUE(ShapeUtil::Equal(folded_root->shape(), slice_shape));
}
178 
// Transposing a random rank-5 F32 constant must fold into a constant with a
// compatible shape whose every cell equals the corresponding input cell.
TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
  HloComputation::Builder builder(TestName());

  const int64_t input_dims[] = {11, 8, 7, 5, 9};
  const int64_t permutation[] = {1, 2, 0, 4, 3};

  const Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
  TF_ASSERT_OK_AND_ASSIGN(
      auto input_literal,
      LiteralUtil::CreateRandomLiteral<F32>(input_shape, 0.0, 1.0));
  // Keep a copy for the element-wise check; the literal itself is moved into
  // the constant instruction.
  auto reference = input_literal.Clone();
  HloInstruction* input_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(input_literal)));

  Shape transposed_shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
  builder.AddInstruction(HloInstruction::CreateTranspose(
      transposed_shape, input_constant, permutation));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(module.get()));
  EXPECT_TRUE(changed);

  HloInstruction* folded_root = entry->root_instruction();
  EXPECT_THAT(folded_root, GmockMatch(m::Constant()));
  EXPECT_TRUE(ShapeUtil::Compatible(folded_root->shape(), transposed_shape));

  using ElementT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
  bool all_cells_match = true;
  folded_root->literal().EachCell<ElementT>(
      [&](absl::Span<const int64_t> output_index, ElementT output_value) {
        // Once a mismatch has been seen there is nothing more to learn.
        if (!all_cells_match) {
          return;
        }
        // Map the output index back to the index it came from in the input.
        std::vector<int64_t> source_index =
            PermuteInverse(output_index, permutation);
        all_cells_match =
            (output_value == reference.Get<ElementT>(source_index));
      });
  EXPECT_TRUE(all_cells_match);
}
212 
213 const char* const kConstantFoldReduce = R"(
214   HloModule ConstantFoldReduce
215 
216   add {
217     a = s32[] parameter(0)
218     b = s32[] parameter(1)
219     ROOT add = s32[] add(a, b)
220   }
221 
222   ENTRY r {
223     x = s32[3] constant({1, 2, 3})
224     init = s32[] constant(0)
225     ROOT reduce = s32[] reduce(x, init), dimensions={0}, to_apply=add
226   })";
227 
TEST_F(HloConstantFoldingTest,ConstantFoldReduce)228 TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
229   TF_ASSERT_OK_AND_ASSIGN(auto m,
230                           ParseAndReturnVerifiedModule(kConstantFoldReduce));
231   HloConstantFolding const_folder;
232   TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(m.get()));
233   EXPECT_TRUE(result);
234 
235   EXPECT_EQ(6, m->entry_computation()
236                    ->root_instruction()
237                    ->literal()
238                    .GetFirstElement<int32_t>());
239 }
240 
// After the layout is cleared on the root instruction of the module's first
// computation, the pass must report no change and leave the reduce in place.
TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kConstantFoldReduce));

  // Root of the first computation returned by the module; presumably the
  // `add` reducer from kConstantFoldReduce (depends on computation order).
  HloComputation* first_computation = *module->computations().begin();
  HloInstruction* first_root = first_computation->root_instruction();
  LayoutUtil::ClearLayout(first_root->mutable_shape());

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folding_pass.Run(module.get()));
  EXPECT_FALSE(changed);

  HloInstruction* entry_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, GmockMatch(m::Reduce()));
}
253 
254 const char* const kConstantFoldLargePad = R"(
255   HloModule ConstantFoldLargePad
256 
257   ENTRY r {
258     a = f32[1,1,1] constant({{{7}}})
259     b = f32[] constant(42)
260     ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63
261   })";
262 
TEST_F(HloConstantFoldingTest,DoesNotFoldLargePad)263 TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) {
264   TF_ASSERT_OK_AND_ASSIGN(auto module,
265                           ParseAndReturnVerifiedModule(kConstantFoldLargePad));
266   HloConstantFolding const_folder;
267   TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
268   EXPECT_FALSE(result);
269 
270   EXPECT_THAT(module->entry_computation()->root_instruction(),
271               GmockMatch(m::Pad(m::Constant(), m::Constant())));
272 }
273 
// A call whose callee contains an after-all instruction must not be folded:
// the pass is expected to report that nothing changed.
TEST_F(HloConstantFoldingTest, DontFoldSubcomputationContainingAfterAll) {
  const char* const kModuleStr = R"(
  HloModule test

  Fn {
    tok = token[] after-all()
    ROOT root = f32[10] iota(), iota_dimension=0
  }

  ENTRY entry {
    ROOT call = f32[10] call(), to_apply=Fn
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloPass(&folding_pass, parsed_module.get()));
  EXPECT_FALSE(changed);
}
293 
// A call must not be folded when an rng instruction is reachable through its
// callee (here: call -> Fn -> fusion -> InnerFn -> rng). The pass is expected
// to report that nothing changed.
TEST_F(HloConstantFoldingTest,
       DontFoldSubcomputationTransitivelyContainingRng) {
  const char* const kModuleStr = R"(
  HloModule test

  InnerFn {
    c0 = f32[] constant(0)
    c1 = f32[] constant(1)
    ROOT rng = f32[10] rng(c0, c1), distribution=rng_uniform
  }

  Fn {
    ROOT fusion = f32[10] fusion(), kind=kLoop, calls=InnerFn
  }

  ENTRY entry {
    ROOT call = f32[10] call(), to_apply=Fn
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloPass(&folding_pass, parsed_module.get()));
  EXPECT_FALSE(changed);
}
319 
// Checks which of four tuple elements get folded: a lone broadcast of a
// constant and an add of two broadcasts must stay as they are, while an add
// of a broadcast and a plain constant (in either operand order) must become
// a constant.
TEST_F(HloConstantFoldingTest, FoldOpsWhereOneOperandIsBroadcast) {
  const char* const kModuleStr = R"(
  HloModule test

  ENTRY entry {
    not_folded1 = f32[4] broadcast(f32[] constant(1))
    not_folded2 = add(f32[4] broadcast(f32[] constant(2)),
                      f32[4] broadcast(f32[] constant(3)))
    folded1 = add(f32[4] broadcast(f32[] constant(5)),
                  f32[4] constant({0,1,2,3}))
    folded2 = add(f32[4] constant({0,1,2,3}),
                  f32[4] broadcast(f32[] constant(5)))
    ROOT root = tuple(not_folded1, not_folded2, folded1, folded2)
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloPass(&folding_pass, parsed_module.get()));
  EXPECT_TRUE(changed);

  HloInstruction* entry_root =
      parsed_module->entry_computation()->root_instruction();
  EXPECT_THAT(
      entry_root,
      GmockMatch(m::Tuple(
          // not_folded1: unchanged broadcast of a constant.
          m::Broadcast(m::Constant()),
          // not_folded2: unchanged add of two broadcasts.
          m::Add(m::Broadcast(m::Constant()), m::Broadcast(m::Constant())),
          // folded1 and folded2: each collapsed to a constant.
          m::Constant(),
          m::Constant())));
}
348 
// Runs the pass on a reduce-window over a broadcast bf16[160,10,10,512]
// operand and expects it to report that the module changed.
TEST_F(HloConstantFoldingTest, BigReduceWindow) {
  constexpr absl::string_view kModuleStr = R"(
    HloModule test

    add_bf16 {
      lhs = bf16[] parameter(0)
      rhs = bf16[] parameter(1)
      ROOT add = bf16[] add(lhs, rhs)
    }

    ENTRY accumulated_all_reduce {
      x = bf16[160,10,10,512]{3,2,1,0} broadcast(bf16[] constant(1.0))
      init = bf16[] constant(0)
      ROOT reduce-window = reduce-window(x, init), window={size=1x2x2x1 stride=1x2x2x1}, to_apply=add_bf16
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  HloConstantFolding folding_pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloPass(&folding_pass, parsed_module.get()));
  EXPECT_TRUE(changed);
}
372 
373 }  // namespace
374 }  // namespace xla
375