• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
17 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
18 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
19 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/test_helpers.h"
23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 
26 namespace xla {
27 namespace {
28 
29 class WhileTransformerTest : public HloTestBase {
30  protected:
WhileTransformerTest()31   WhileTransformerTest()
32       : module_(CreateNewVerifiedModule()),
33         induction_variable_shape_(ShapeUtil::MakeShape(S32, {})),
34         data_shape_(ShapeUtil::MakeShape(F32, {8})),
35         condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {}
36 
BuildConditionComputation(const int64_t tuple_index,const int64_t limit)37   std::unique_ptr<HloComputation> BuildConditionComputation(
38       const int64_t tuple_index, const int64_t limit) {
39     auto builder = HloComputation::Builder(TestName() + ".Condition");
40     auto limit_const = builder.AddInstruction(
41         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(limit)));
42     auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
43         0, GetLoopStateShape(tuple_index), "loop_state"));
44     auto induction_variable =
45         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
46             limit_const->shape(), loop_state, tuple_index));
47     builder.AddInstruction(HloInstruction::CreateCompare(
48         condition_result_shape_, induction_variable, limit_const,
49         ComparisonDirection::kLt));
50     return builder.Build();
51   }
52 
BuildBodyComputation(const int64_t ind_var_tuple_index,const int64_t data_tuple_index,const int64_t increment)53   std::unique_ptr<HloComputation> BuildBodyComputation(
54       const int64_t ind_var_tuple_index, const int64_t data_tuple_index,
55       const int64_t increment) {
56     auto builder = HloComputation::Builder(TestName() + ".Body");
57     // Create param instruction to access loop state.
58     auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
59         0, GetLoopStateShape(ind_var_tuple_index), "loop_state"));
60     // Update the induction variable GTE(ind_var_tuple_index).
61     auto induction_variable =
62         builder.AddInstruction(HloInstruction::CreateGetTupleElement(
63             induction_variable_shape_, loop_state, ind_var_tuple_index));
64     auto inc = builder.AddInstruction(HloInstruction::CreateConstant(
65         LiteralUtil::CreateR0<int32_t>(increment)));
66     auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
67         induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
68     // Update data GTE(data_tuple_index).
69     auto data = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
70         data_shape_, loop_state, data_tuple_index));
71     // Use 'induction_variable' in computation with no path to output tuple.
72     auto cast = builder.AddInstruction(HloInstruction::CreateBitcastConvert(
73         ShapeUtil::MakeShape(F32, {}), induction_variable));
74     auto update = builder.AddInstruction(
75         HloInstruction::CreateBroadcast(data_shape_, cast, {}));
76     auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
77         data_shape_, HloOpcode::kAdd, data, update));
78     // Create output Tuple.
79     ind_var_tuple_index == 0
80         ? builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}))
81         : builder.AddInstruction(HloInstruction::CreateTuple({add1, add0}));
82     return builder.Build();
83   }
84 
BuildWhileInstruction(HloComputation * condition,HloComputation * body,const int64_t ind_var_tuple_index,const int64_t ind_var_init)85   HloInstruction* BuildWhileInstruction(HloComputation* condition,
86                                         HloComputation* body,
87                                         const int64_t ind_var_tuple_index,
88                                         const int64_t ind_var_init) {
89     auto builder = HloComputation::Builder(TestName() + ".While");
90     auto induction_var_init =
91         builder.AddInstruction(HloInstruction::CreateConstant(
92             LiteralUtil::CreateR0<int32_t>(ind_var_init)));
93     auto data_init = builder.AddInstruction(
94         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
95             {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
96     auto loop_state_init =
97         ind_var_tuple_index == 0
98             ? builder.AddInstruction(
99                   HloInstruction::CreateTuple({induction_var_init, data_init}))
100             : builder.AddInstruction(
101                   HloInstruction::CreateTuple({data_init, induction_var_init}));
102     auto while_hlo = builder.AddInstruction(
103         HloInstruction::CreateWhile(GetLoopStateShape(ind_var_tuple_index),
104                                     condition, body, loop_state_init));
105     module_->AddEntryComputation(builder.Build());
106     return while_hlo;
107   }
108 
GetLoopStateShape(const int64_t ind_var_tuple_index)109   Shape GetLoopStateShape(const int64_t ind_var_tuple_index) {
110     if (ind_var_tuple_index == 0) {
111       return ShapeUtil::MakeTupleShape(
112           {induction_variable_shape_, data_shape_});
113     } else {
114       return ShapeUtil::MakeTupleShape(
115           {data_shape_, induction_variable_shape_});
116     }
117   }
118 
119   std::unique_ptr<HloModule> module_;
120   Shape induction_variable_shape_;
121   Shape data_shape_;
122   Shape condition_result_shape_;
123 };
124 
TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) {
  // Layout (ind_var, data): the induction variable occupies tuple slot 0.
  // Starting at 0 and stepping by 1 while ind_var < 10 gives 10 iterations.
  HloComputation* const cond_comp =
      module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));
  HloComputation* const body_comp =
      module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1));
  HloInstruction* const loop =
      BuildWhileInstruction(cond_comp, body_comp, 0, 0);

  const auto trip_count = ComputeWhileLoopTripCount(loop);
  ASSERT_TRUE(trip_count.has_value());
  EXPECT_EQ(10, *trip_count);
}
135 
TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) {
  // Layout (data, ind_var): the induction variable occupies tuple slot 1.
  // Starting at 0 and stepping by 1 while ind_var < 10 gives 10 iterations.
  HloComputation* const cond_comp =
      module_->AddEmbeddedComputation(BuildConditionComputation(1, 10));
  HloComputation* const body_comp =
      module_->AddEmbeddedComputation(BuildBodyComputation(1, 0, 1));
  HloInstruction* const loop =
      BuildWhileInstruction(cond_comp, body_comp, 1, 0);

  const auto trip_count = ComputeWhileLoopTripCount(loop);
  ASSERT_TRUE(trip_count.has_value());
  EXPECT_EQ(10, *trip_count);
}
146 
TEST_F(WhileTransformerTest, ImpossibleLoopLimit) {
  // The initial value 10 already fails `ind_var < 5`, so the body never runs
  // and the expected trip count is 0.
  HloComputation* const cond_comp =
      module_->AddEmbeddedComputation(BuildConditionComputation(0, 5));
  HloComputation* const body_comp =
      module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, 1));
  HloInstruction* const loop =
      BuildWhileInstruction(cond_comp, body_comp, 0, 10);

  const auto trip_count = ComputeWhileLoopTripCount(loop);
  ASSERT_TRUE(trip_count.has_value());
  EXPECT_EQ(0, *trip_count);
}
157 
TEST_F(WhileTransformerTest, InvalidLoopIncrement) {
  // Starting at 0 and stepping by -1 moves away from the limit 10, so the
  // analysis is expected to report no trip count.
  HloComputation* const cond_comp =
      module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));
  HloComputation* const body_comp =
      module_->AddEmbeddedComputation(BuildBodyComputation(0, 1, -1));
  HloInstruction* const loop =
      BuildWhileInstruction(cond_comp, body_comp, 0, 0);

  const auto trip_count = ComputeWhileLoopTripCount(loop);
  ASSERT_FALSE(trip_count.has_value());
}
167 
168 }  // namespace
169 }  // namespace xla
170