• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/core/lib/core/status_test_util.h"
31 
32 namespace op = xla::testing::opcode_matchers;
33 
34 namespace xla {
35 namespace {
36 
37 class TupleSimplifierTest : public HloTestBase {
38  protected:
Run(HloModule * module,bool change_expected)39   void Run(HloModule* module, bool change_expected) {
40     auto changed_status = RunHloPass(TupleSimplifier(), module);
41     TF_ASSERT_OK(changed_status.status());
42     EXPECT_EQ(change_expected, changed_status.ValueOrDie());
43   }
Run(HloModule * module,bool change_expected,bool exclude_entry)44   void Run(HloModule* module, bool change_expected, bool exclude_entry) {
45     auto changed_status = RunHloPass(TupleSimplifier(exclude_entry), module);
46     TF_ASSERT_OK(changed_status.status());
47     EXPECT_EQ(change_expected, changed_status.ValueOrDie());
48   }
49 
50   const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
51   const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
52       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {}),
53        ShapeUtil::MakeShape(F32, {})});
54 };
55 
TEST_F(TupleSimplifierTest, TupleOfParameters) {
  // A tuple whose operands are all plain parameters offers nothing to
  // simplify, so the pass is expected to report no change.
  auto module = CreateNewVerifiedModule();

  HloComputation::Builder b(TestName());
  HloInstruction* first = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* second = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* third = b.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_shape_, "param2"));
  b.AddInstruction(HloInstruction::CreateTuple({first, second, third}));
  module->AddEntryComputation(b.Build());

  Run(module.get(), /*change_expected=*/false);
}
71 
TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) {
  // Extracting an element directly from a tuple-shaped parameter must be left
  // alone: there is no Tuple instruction to look through.
  auto module = CreateNewVerifiedModule();

  HloComputation::Builder b(TestName());
  HloInstruction* tuple_param = b.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape_, "param"));
  b.AddInstruction(HloInstruction::CreateGetTupleElement(
      scalar_shape_, tuple_param, /*index=*/1));
  module->AddEntryComputation(b.Build());

  Run(module.get(), /*change_expected=*/false);
}
84 
TEST_F(TupleSimplifierTest, GteOfTuple) {
  // GTE(Tuple(a, b, c), 1) should be replaced by `b` directly.
  auto module = CreateNewVerifiedModule();

  HloComputation::Builder b(TestName());
  HloInstruction* first = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* second = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* third = b.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_shape_, "param2"));
  HloInstruction* packed =
      b.AddInstruction(HloInstruction::CreateTuple({first, second, third}));
  HloInstruction* extracted = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, packed, 1));

  HloComputation* computation = module->AddEntryComputation(b.Build());

  // Before the pass the root is the GTE itself.
  EXPECT_EQ(computation->root_instruction(), extracted);

  Run(module.get(), /*change_expected=*/true);

  // After the pass the root is the parameter that sat at tuple index 1.
  EXPECT_EQ(computation->root_instruction(), second);
}
108 
TEST_F(TupleSimplifierTest, GteOfTupleChain) {
  // Builds param -> Tuple -> GTE -> Tuple -> GTE -> ... -> Negate and checks
  // that the whole Tuple/GTE chain collapses so Negate consumes the parameter.
  auto module = CreateNewVerifiedModule();

  HloComputation::Builder b(TestName());
  HloInstruction* const source = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));

  constexpr int kChainLength = 10;
  HloInstruction* current = source;
  for (int link = 0; link < kChainLength; ++link) {
    // Wrap the current value in a 3-tuple, then pull element 1 back out.
    HloInstruction* wrapped = b.AddInstruction(
        HloInstruction::CreateTuple({current, current, current}));
    current = b.AddInstruction(
        HloInstruction::CreateGetTupleElement(scalar_shape_, wrapped, 1));
  }
  b.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, current));

  HloComputation* computation = module->AddEntryComputation(b.Build());

  EXPECT_THAT(computation->root_instruction(),
              op::Negate(op::GetTupleElement(op::Tuple())));

  Run(module.get(), /*change_expected=*/true);

  EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter()));
}
136 
TEST_F(TupleSimplifierTest, NestedGteOfTuples) {
  // First nests a value kNestingDepth levels deep with a chain of Tuple
  // instructions, then unwraps it again with an equally long chain of GTE
  // instructions. The pass should collapse everything back to the parameter.
  auto module = CreateNewVerifiedModule();

  HloComputation::Builder b(TestName());
  HloInstruction* const source = b.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));

  constexpr int kNestingDepth = 5;

  // Wrap: each level is a 2-tuple holding the previous level twice.
  HloInstruction* wrapped = source;
  for (int level = 0; level < kNestingDepth; ++level) {
    wrapped =
        b.AddInstruction(HloInstruction::CreateTuple({wrapped, wrapped}));
  }

  // Unwrap: repeatedly take element 0 of the current tuple.
  HloInstruction* unwrapped = wrapped;
  for (int level = 0; level < kNestingDepth; ++level) {
    const Shape& inner_shape =
        ShapeUtil::GetTupleElementShape(unwrapped->shape(), 0);
    unwrapped = b.AddInstruction(
        HloInstruction::CreateGetTupleElement(inner_shape, unwrapped, 0));
  }

  HloComputation* computation = module->AddEntryComputation(b.Build());

  EXPECT_EQ(computation->root_instruction(), unwrapped);

  Run(module.get(), /*change_expected=*/true);

  EXPECT_EQ(computation->root_instruction(), source);
}
167 
TEST_F(TupleSimplifierTest, TupleOfGteInstructions) {
  // Tuple(GTE(p, 0), GTE(p, 1), GTE(p, 2)) rebuilds `p` element by element and
  // should be collapsed to `p` itself.
  auto module = CreateNewVerifiedModule();

  HloComputation::Builder b(TestName());
  HloInstruction* const input = b.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape_, "param"));
  HloInstruction* elem0 = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* elem1 = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  HloInstruction* elem2 = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 2));
  HloInstruction* repacked =
      b.AddInstruction(HloInstruction::CreateTuple({elem0, elem1, elem2}));

  HloComputation* computation = module->AddEntryComputation(b.Build());

  EXPECT_EQ(computation->root_instruction(), repacked);

  Run(module.get(), /*change_expected=*/true);

  EXPECT_EQ(computation->root_instruction(), input);
}
192 
TEST_F(TupleSimplifierTest, IncompatibleTuples) {
  // A tuple -> GTE -> tuple pattern must not be simplified when the rebuilt
  // tuple's shape is incompatible with the source tuple's shape.
  auto module = CreateNewVerifiedModule();

  HloComputation::Builder b(TestName());
  HloInstruction* const input = b.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape_, "param"));
  HloInstruction* elem0 = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 0));
  HloInstruction* elem1 = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, input, 1));
  // The parameter tuple has three elements while the rebuilt tuple has only
  // two, so the rebuilt tuple cannot be replaced by the parameter.
  HloInstruction* const partial =
      b.AddInstruction(HloInstruction::CreateTuple({elem0, elem1}));

  HloComputation* computation = module->AddEntryComputation(b.Build());

  EXPECT_EQ(computation->root_instruction(), partial);

  Run(module.get(), /*change_expected=*/false);

  // The root must still be the two-element tuple.
  EXPECT_EQ(computation->root_instruction(), partial);
}
217 
TEST_F(TupleSimplifierTest,CanExcludeEntryComputation)218 TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) {
219   //  Verify that the root computation can be excluded
220   auto module = CreateNewVerifiedModule();
221 
222   HloInstruction* p0;
223   HloInstruction* p1;
224   HloComputation* c0;
225   HloComputation* c1;
226   HloComputation* entry;
227 
228   {
229     HloComputation::Builder builder(TestName() + "_1");
230     p0 = builder.AddInstruction(
231         HloInstruction::CreateParameter(0, tuple_shape_, "param"));
232     HloInstruction* gte0 = builder.AddInstruction(
233         HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 0));
234     HloInstruction* gte1 = builder.AddInstruction(
235         HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 1));
236     HloInstruction* gte2 = builder.AddInstruction(
237         HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 2));
238 
239     builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));
240 
241     c0 = module->AddEmbeddedComputation(builder.Build());
242   }
243   {
244     HloComputation::Builder builder(TestName() + "_2");
245     p1 = builder.AddInstruction(
246         HloInstruction::CreateParameter(0, tuple_shape_, "param"));
247     HloInstruction* gte0 = builder.AddInstruction(
248         HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 0));
249     HloInstruction* gte1 = builder.AddInstruction(
250         HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 1));
251     HloInstruction* gte2 = builder.AddInstruction(
252         HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 2));
253 
254     builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));
255 
256     c1 = module->AddEmbeddedComputation(builder.Build());
257   }
258   {
259     HloComputation::Builder builder(TestName() + "_Entry");
260     HloInstruction* tuple_param = builder.AddInstruction(
261         HloInstruction::CreateParameter(0, tuple_shape_, "param"));
262     HloInstruction* call0 = builder.AddInstruction(
263         HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c0));
264     HloInstruction* call1 = builder.AddInstruction(
265         HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c1));
266     HloInstruction* gte0 = builder.AddInstruction(
267         HloInstruction::CreateGetTupleElement(scalar_shape_, call0, 0));
268     HloInstruction* gte1 = builder.AddInstruction(
269         HloInstruction::CreateGetTupleElement(scalar_shape_, call1, 1));
270     HloInstruction* tuple0 =
271         builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
272     HloInstruction* gte2 = builder.AddInstruction(
273         HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 0));
274     HloInstruction* gte3 = builder.AddInstruction(
275         HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 1));
276 
277     builder.AddInstruction(HloInstruction::CreateTuple({gte2, gte3}));
278 
279     entry = module->AddEntryComputation(builder.Build());
280   }
281 
282   Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/true);
283 
284   EXPECT_THAT(c0->root_instruction(), p0);
285   EXPECT_THAT(c1->root_instruction(), p1);
286   EXPECT_THAT(entry->instruction_count(), 9);
287 }
288 
TEST_F(TupleSimplifierTest, ShardingLoss) {
  // The GTE below is annotated with a different sharding ({replicated}) than
  // the parameter it extracts from the tuple ({devices=[2]0,1}). The pass is
  // expected to leave this module unchanged.
  const char* const kHloText = R"(
    HloModule m

    ENTRY test {
      p0 = s32[10] parameter(0), sharding={devices=[2]0,1}
      t = (s32[10]) tuple(p0)
      ROOT %gte = s32[10] get-tuple-element(t), index=0, sharding={replicated}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  Run(module.get(), /*change_expected=*/false);
}
302 
303 }  // namespace
304 }  // namespace xla
305