1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
17 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
18 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
19 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/test_helpers.h"
23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25
26 namespace xla {
27 namespace {
28
29 class WhileTransformerTest : public HloTestBase {
30 protected:
WhileTransformerTest()31 WhileTransformerTest()
32 : module_(CreateNewVerifiedModule()),
33 induction_variable_shape_(ShapeUtil::MakeShape(S32, {})),
34 data_shape_(ShapeUtil::MakeShape(F32, {8})),
35 condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {}
36
BuildConditionComputation(const int64_t tuple_index,const int64_t limit)37 std::unique_ptr<HloComputation> BuildConditionComputation(
38 const int64_t tuple_index, const int64_t limit) {
39 auto builder = HloComputation::Builder(TestName() + ".Condition");
40 auto limit_const = builder.AddInstruction(
41 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(limit)));
42 auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
43 0, GetLoopStateShape(tuple_index), "loop_state"));
44 auto induction_variable =
45 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
46 limit_const->shape(), loop_state, tuple_index));
47 builder.AddInstruction(HloInstruction::CreateCompare(
48 condition_result_shape_, induction_variable, limit_const,
49 ComparisonDirection::kLt));
50 return builder.Build();
51 }
52
BuildBodyComputation(const int64_t ind_var_tuple_index,const int64_t data_tuple_index,const int64_t increment)53 std::unique_ptr<HloComputation> BuildBodyComputation(
54 const int64_t ind_var_tuple_index, const int64_t data_tuple_index,
55 const int64_t increment) {
56 auto builder = HloComputation::Builder(TestName() + ".Body");
57 // Create param instruction to access loop state.
58 auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
59 0, GetLoopStateShape(ind_var_tuple_index), "loop_state"));
60 // Update the induction variable GTE(ind_var_tuple_index).
61 auto induction_variable =
62 builder.AddInstruction(HloInstruction::CreateGetTupleElement(
63 induction_variable_shape_, loop_state, ind_var_tuple_index));
64 auto inc = builder.AddInstruction(HloInstruction::CreateConstant(
65 LiteralUtil::CreateR0<int32_t>(increment)));
66 auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
67 induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
68 // Update data GTE(data_tuple_index).
69 auto data = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
70 data_shape_, loop_state, data_tuple_index));
71 // Use 'induction_variable' in computation with no path to output tuple.
72 auto cast = builder.AddInstruction(HloInstruction::CreateBitcastConvert(
73 ShapeUtil::MakeShape(F32, {}), induction_variable));
74 auto update = builder.AddInstruction(
75 HloInstruction::CreateBroadcast(data_shape_, cast, {}));
76 auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
77 data_shape_, HloOpcode::kAdd, data, update));
78 // Create output Tuple.
79 ind_var_tuple_index == 0
80 ? builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}))
81 : builder.AddInstruction(HloInstruction::CreateTuple({add1, add0}));
82 return builder.Build();
83 }
84
BuildWhileInstruction(HloComputation * condition,HloComputation * body,const int64_t ind_var_tuple_index,const int64_t ind_var_init)85 HloInstruction* BuildWhileInstruction(HloComputation* condition,
86 HloComputation* body,
87 const int64_t ind_var_tuple_index,
88 const int64_t ind_var_init) {
89 auto builder = HloComputation::Builder(TestName() + ".While");
90 auto induction_var_init =
91 builder.AddInstruction(HloInstruction::CreateConstant(
92 LiteralUtil::CreateR0<int32_t>(ind_var_init)));
93 auto data_init = builder.AddInstruction(
94 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
95 {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
96 auto loop_state_init =
97 ind_var_tuple_index == 0
98 ? builder.AddInstruction(
99 HloInstruction::CreateTuple({induction_var_init, data_init}))
100 : builder.AddInstruction(
101 HloInstruction::CreateTuple({data_init, induction_var_init}));
102 auto while_hlo = builder.AddInstruction(
103 HloInstruction::CreateWhile(GetLoopStateShape(ind_var_tuple_index),
104 condition, body, loop_state_init));
105 module_->AddEntryComputation(builder.Build());
106 return while_hlo;
107 }
108
GetLoopStateShape(const int64_t ind_var_tuple_index)109 Shape GetLoopStateShape(const int64_t ind_var_tuple_index) {
110 if (ind_var_tuple_index == 0) {
111 return ShapeUtil::MakeTupleShape(
112 {induction_variable_shape_, data_shape_});
113 } else {
114 return ShapeUtil::MakeTupleShape(
115 {data_shape_, induction_variable_shape_});
116 }
117 }
118
119 std::unique_ptr<HloModule> module_;
120 Shape induction_variable_shape_;
121 Shape data_shape_;
122 Shape condition_result_shape_;
123 };
124
TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) {
  // Loop state is (induction_variable, data). The counter starts at 0, steps
  // by 1 and runs while it is < 10, so exactly 10 iterations are expected.
  HloComputation* const cond = module_->AddEmbeddedComputation(
      BuildConditionComputation(/*tuple_index=*/0, /*limit=*/10));
  HloComputation* const loop_body =
      module_->AddEmbeddedComputation(BuildBodyComputation(
          /*ind_var_tuple_index=*/0, /*data_tuple_index=*/1,
          /*increment=*/1));
  HloInstruction* const while_hlo = BuildWhileInstruction(
      cond, loop_body, /*ind_var_tuple_index=*/0, /*ind_var_init=*/0);

  const auto result = ComputeWhileLoopTripCount(while_hlo);
  ASSERT_TRUE(result.has_value());
  EXPECT_EQ(10, *result);
}
135
TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) {
  // Same loop as the element-0 case, but with the tuple order flipped:
  // loop state is (data, induction_variable). Trip count is still 10.
  HloComputation* const cond = module_->AddEmbeddedComputation(
      BuildConditionComputation(/*tuple_index=*/1, /*limit=*/10));
  HloComputation* const loop_body =
      module_->AddEmbeddedComputation(BuildBodyComputation(
          /*ind_var_tuple_index=*/1, /*data_tuple_index=*/0,
          /*increment=*/1));
  HloInstruction* const while_hlo = BuildWhileInstruction(
      cond, loop_body, /*ind_var_tuple_index=*/1, /*ind_var_init=*/0);

  const auto result = ComputeWhileLoopTripCount(while_hlo);
  ASSERT_TRUE(result.has_value());
  EXPECT_EQ(10, *result);
}
146
TEST_F(WhileTransformerTest, ImpossibleLoopLimit) {
  // The counter starts at 10 but the loop only runs while it is < 5. The
  // condition is false on entry, so the body never executes: trip count 0.
  HloComputation* const cond = module_->AddEmbeddedComputation(
      BuildConditionComputation(/*tuple_index=*/0, /*limit=*/5));
  HloComputation* const loop_body =
      module_->AddEmbeddedComputation(BuildBodyComputation(
          /*ind_var_tuple_index=*/0, /*data_tuple_index=*/1,
          /*increment=*/1));
  HloInstruction* const while_hlo = BuildWhileInstruction(
      cond, loop_body, /*ind_var_tuple_index=*/0, /*ind_var_init=*/10);

  const auto result = ComputeWhileLoopTripCount(while_hlo);
  ASSERT_TRUE(result.has_value());
  EXPECT_EQ(0, *result);
}
157
TEST_F(WhileTransformerTest, InvalidLoopIncrement) {
  // The counter starts at 0 and steps by -1 while it is < 10. It moves away
  // from the limit, so no finite trip count exists and the analysis must
  // report "unknown".
  HloComputation* const cond = module_->AddEmbeddedComputation(
      BuildConditionComputation(/*tuple_index=*/0, /*limit=*/10));
  HloComputation* const loop_body =
      module_->AddEmbeddedComputation(BuildBodyComputation(
          /*ind_var_tuple_index=*/0, /*data_tuple_index=*/1,
          /*increment=*/-1));
  HloInstruction* const while_hlo = BuildWhileInstruction(
      cond, loop_body, /*ind_var_tuple_index=*/0, /*ind_var_init=*/0);

  const auto result = ComputeWhileLoopTripCount(while_hlo);
  ASSERT_FALSE(result.has_value());
}
167
168 } // namespace
169 } // namespace xla
170