• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <array>
17 
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
20 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
21 #include "tensorflow/compiler/xla/tests/test_macros.h"
22 #include "tensorflow/core/platform/test.h"
23 
24 namespace xla {
25 namespace {
26 
27 class TokenHloTest : public HloTestBase {};
28 
// Builds an entry computation whose only instruction creates a token, runs it,
// and checks that the result is the token literal.
XLA_TEST_F(TokenHloTest, SingleTokenInstruction) {
  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  builder.AddInstruction(HloInstruction::CreateToken());

  module->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {}));
  // LiteralTestUtil::Equal takes (expected, actual); keep that order so a
  // failure message labels the two literals correctly.
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateToken(), result));
}
39 
// Wraps a freshly created token in a one-element tuple and checks that the
// executed result is a tuple holding the token literal.
XLA_TEST_F(TokenHloTest, TokenInTuple) {
  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  auto token = builder.AddInstruction(HloInstruction::CreateToken());
  builder.AddInstruction(HloInstruction::CreateTuple({token}));

  module->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {}));
  Literal token_literal = LiteralUtil::CreateToken();
  // LiteralTestUtil::Equal takes (expected, actual).
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::MakeTuple({&token_literal}), result));
}
53 
// Joins several independently created tokens with a single after-all
// instruction and checks that the executed result is the token literal.
XLA_TEST_F(TokenHloTest, TokenTree) {
  std::unique_ptr<HloModule> module = CreateNewVerifiedModule();
  auto builder = HloComputation::Builder(TestName());
  auto token0 = builder.AddInstruction(HloInstruction::CreateToken());
  auto token1 = builder.AddInstruction(HloInstruction::CreateToken());
  auto token2 = builder.AddInstruction(HloInstruction::CreateToken());
  // NOTE(review): token0 appears twice, presumably on purpose so that an
  // after-all with a repeated operand is covered as well.
  builder.AddInstruction(
      HloInstruction::CreateAfterAll({token0, token0, token1, token2}));

  module->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {}));
  // LiteralTestUtil::Equal takes (expected, actual).
  EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateToken(), result));
}
68 
// The verifier must reject an entry computation that takes a token-shaped
// parameter (here parameter 1), reporting which parameter is at fault.
XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
  auto builder = HloComputation::Builder(TestName());
  const auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
  const auto token_shape = ShapeUtil::MakeTokenShape();
  builder.AddInstruction(HloInstruction::CreateParameter(0, scalar_f32, "p0"));
  builder.AddInstruction(HloInstruction::CreateParameter(1, token_shape, "p1"));
  builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(42)));

  // Unverified on purpose: the module is invalid and is checked explicitly.
  std::unique_ptr<HloModule> module = CreateNewUnverifiedModule();
  module->AddEntryComputation(builder.Build());

  HloVerifier verifier(/*layout_sensitive=*/false,
                       /*allow_mixed_precision=*/false);
  const Status status = verifier.Run(module.get()).status();
  ASSERT_IS_NOT_OK(status);
  EXPECT_THAT(
      status.error_message(),
      ::testing::HasSubstr("Entry parameter 1 is or contains a token shape"));
}
89 
// A token nested inside a tuple-shaped entry parameter must also be rejected
// by the verifier.
XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) {
  const auto param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeTokenShape()});
  auto builder = HloComputation::Builder(TestName());
  builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));

  // Unverified on purpose: the module is invalid and is checked explicitly.
  std::unique_ptr<HloModule> module = CreateNewUnverifiedModule();
  module->AddEntryComputation(builder.Build());

  HloVerifier verifier(/*layout_sensitive=*/false,
                       /*allow_mixed_precision=*/false);
  const Status status = verifier.Run(module.get()).status();
  ASSERT_IS_NOT_OK(status);
  EXPECT_THAT(
      status.error_message(),
      ::testing::HasSubstr("Entry parameter 0 is or contains a token shape"));
}
109 
// Threads a token through a while loop. The token is created by an after-all
// in the entry computation; on each iteration the body consumes it with
// another after-all and passes the new token on. The loop counts to 42.
XLA_TEST_F(TokenHloTest, TokenInWhileLoop) {
  DebugOptions debug_options = GetDebugOptionsForTest();
  // Disable hlo-module-dce: otherwise it removes the generated token
  // instructions that this test exercises.
  debug_options.add_xla_disable_hlo_passes("hlo-module-dce");

  const std::string hlo_text = R"(
HloModule TokenInWhileLoop

%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
  %param.1 = (s32[], token[]) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
  %after-all = token[] after-all(token[] %get-tuple-element.2)
  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}

%Cond (param: (s32[], token[])) -> pred[] {
  %param = (s32[], token[]) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
  %constant = s32[] constant(42)
  ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %TokenInWhileLoop () -> s32[] {
  %zero = s32[] constant(0)
  %init_token = token[] after-all()
  %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
  ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      HloRunner::CreateModuleFromString(hlo_text, debug_options));
  EXPECT_TRUE(RunAndCompare(std::move(module), error_spec_));
}
151 
// Passes a token into a conditional. The true branch forwards the incoming
// token and returns 42. The false branch creates a fresh token and returns its
// s32 operand (7). Both predicate values are exercised.
XLA_TEST_F(TokenHloTest, TokenInConditional) {
  const std::string module_string = R"(
HloModule TokenInConditional

%True (param.1: token[]) -> (s32[], token[]) {
  %param.1 = token[] parameter(0)
  %forty_two = s32[] constant(42)
  ROOT %tuple = (s32[], token[]) tuple(s32[] %forty_two, token[] %param.1)
}

%False (param.2: s32[]) -> (s32[], token[]) {
  %param.2 = s32[] parameter(0)
  %new_token = token[] after-all()
  ROOT %tuple = (s32[], token[]) tuple(s32[] %param.2, token[] %new_token)
}

ENTRY %TokenInConditional (param.3: pred[]) -> s32[] {
  %param.3 = pred[] parameter(0)
  %init_token = token[] after-all()
  %seven = s32[] constant(7)
  %cond = (s32[], token[]) conditional(pred[] %param.3, token[] %init_token, s32[] %seven), true_computation=True, false_computation=False
  ROOT %root = s32[] get-tuple-element((s32[], token[]) %cond), index=0
}
)";

  DebugOptions debug_options = GetDebugOptionsForTest();
  // Disable hlo-module-dce: otherwise it removes the generated token
  // instructions that this test exercises.
  debug_options.add_xla_disable_hlo_passes("hlo-module-dce");

  // Parses a fresh module (Execute takes ownership of it), runs it with the
  // given predicate and checks the s32 result. A fatal assertion inside this
  // lambda only aborts the current case, so the other case still runs.
  auto run_case = [&](bool pred, int32_t expected) {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<HloModule> module,
        HloRunner::CreateModuleFromString(module_string, debug_options));
    auto arg = LiteralUtil::CreateR0<bool>(pred);
    TF_ASSERT_OK_AND_ASSIGN(Literal result, Execute(std::move(module), {&arg}));
    EXPECT_EQ(expected, result.Get<int32_t>({}))
        << "predicate = " << (pred ? "true" : "false");
  };

  run_case(/*pred=*/true, /*expected=*/42);  // True case.
  run_case(/*pred=*/false, /*expected=*/7);  // False case.
}
201 
// Orders the negation of p1 after the add via an after-all token and an
// add-dependency, then checks (p0 + 42) * (-p1) for p0 = 10, p1 = 3.
XLA_TEST_F(TokenHloTest, AddDependency) {
  const std::string hlo_text = R"(
HloModule AddDependency, is_scheduled=true

// Computes (p0 + 42) * (-p1)
// where there is a dependency from the add to the negation using a token
// with after-all and add-dependency instructions.
ENTRY %AddDependency (p0: f32[], p1: f32[]) -> f32[] {
  %p0 = f32[] parameter(0)
  %p1 = f32[] parameter(1)

  %forty_two = f32[] constant(42.0)
  %add = f32[] add(f32[] %p0, f32[] %forty_two)
  %token0 = token[] after-all(f32[] %add)
  %p1_after_token = f32[] add-dependency(f32[] %p1, token[] %token0)
  %neg = f32[] negate(f32[] %p1_after_token)
  ROOT %product = f32[] multiply(f32[] %add, f32[] %neg)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));

  auto lhs_arg = LiteralUtil::CreateR0<float>(10.0);
  auto rhs_arg = LiteralUtil::CreateR0<float>(3.0);
  auto actual = ExecuteNoHloPasses(std::move(module), {&lhs_arg, &rhs_arg});
  // (10 + 42) * -3 == -156.
  EXPECT_EQ(LiteralUtil::CreateR0<float>(-156.0), actual);
}
229 
// Applies add-dependency to a constant: the constant 42 is made to depend on a
// token derived from p0, and the product p0 * 42 is checked for p0 = 10.
XLA_TEST_F(TokenHloTest, AddDependencyOfConstant) {
  const std::string hlo_text = R"(
HloModule AddDependencyOfConstant, is_scheduled=true

ENTRY %AddDependency (p0: f32[]) -> f32[] {
  %p0 = f32[] parameter(0)
  %forty_two = f32[] constant(42.0)
  %token0 = token[] after-all(f32[] %p0)
  %forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token0)
  ROOT %product = f32[] multiply(f32[] %p0, f32[] %forty_two_after_token)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));

  auto arg = LiteralUtil::CreateR0<float>(10.0);
  auto actual = ExecuteNoHloPasses(std::move(module), {&arg});
  EXPECT_EQ(LiteralUtil::CreateR0<float>(420.0), actual);
}
249 
// Uses an add-dependency as the computation root; the result must equal its
// first operand, the element-wise negation of the input vector.
XLA_TEST_F(TokenHloTest, AddDependencyAsRoot) {
  const std::string hlo_text = R"(
HloModule AddDependencyAsRoot, is_scheduled=true
ENTRY %AddDependency (p: f32[3]) -> f32[3] {
  %p = f32[3] parameter(0)
  %neg = f32[3] negate(f32[3] %p)
  %token0 = token[] after-all()
  ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));

  auto vec = LiteralUtil::CreateR1<float>({1.0, 3.0, 7.0});
  auto actual = ExecuteNoHloPasses(std::move(module), {&vec});
  EXPECT_EQ(LiteralUtil::CreateR1<float>({-1.0, -3.0, -7.0}), actual);
}
267 
// Applies add-dependency to a tuple that itself contains a token, then reads
// elements 0 and 2 back out and checks their difference p0 - p1.
XLA_TEST_F(TokenHloTest, TupleShapedAddDependency) {
  const std::string hlo_text = R"(
HloModule TupleShapedAddDependency, is_scheduled=true
ENTRY %TupleShapedAddDependency (p0: f32[3], p1: f32[3]) -> f32[3] {
  %p0 = f32[3] parameter(0)
  %p1 = f32[3] parameter(1)
  %forty_two = f32[] constant(42.0)
  %token0 = token[] after-all()
  %tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token0, f32[3] %p1, f32[] %forty_two)
  %add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token0)
  %elem0 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=0
  %elem2 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=2
  ROOT %diff = f32[3] subtract(f32[3] %elem0, f32[3] %elem2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));

  auto minuend = LiteralUtil::CreateR1<float>({3.0, 3.0, 47.0});
  auto subtrahend = LiteralUtil::CreateR1<float>({1.0, -2.0, 2.0});
  auto actual =
      ExecuteNoHloPasses(std::move(module), {&minuend, &subtrahend});
  EXPECT_EQ(LiteralUtil::CreateR1<float>({2.0, 5.0, 45.0}), actual);
}
291 
292 }  // namespace
293 }  // namespace xla
294