1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <array>
17
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
20 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
21 #include "tensorflow/compiler/xla/tests/test_macros.h"
22 #include "tensorflow/core/platform/test.h"
23
24 namespace xla {
25 namespace {
26
27 class TokenHloTest : public HloTestBase {};
28
// Runs an entry computation whose only instruction creates a token and checks
// that the returned literal is itself a token.
XLA_TEST_F(TokenHloTest, SingleTokenInstruction) {
  std::unique_ptr<HloModule> hlo_module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  entry_builder.AddInstruction(HloInstruction::CreateToken());
  hlo_module->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal output, Execute(std::move(hlo_module), {}));
  EXPECT_TRUE(LiteralTestUtil::Equal(output, LiteralUtil::CreateToken()));
}
39
// Wraps a freshly created token in a one-element tuple and checks that the
// executed result equals a tuple literal holding a single token.
XLA_TEST_F(TokenHloTest, TokenInTuple) {
  std::unique_ptr<HloModule> hlo_module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  auto* token = entry_builder.AddInstruction(HloInstruction::CreateToken());
  entry_builder.AddInstruction(HloInstruction::CreateTuple({token}));
  hlo_module->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal output, Execute(std::move(hlo_module), {}));

  // Build the expected (token[]) tuple from a standalone token literal.
  Literal expected_element = LiteralUtil::CreateToken();
  Literal expected_tuple = LiteralUtil::MakeTuple({&expected_element});
  EXPECT_TRUE(LiteralTestUtil::Equal(output, expected_tuple));
}
53
// Joins three independently created tokens with a single after-all and checks
// that the result is a token.
XLA_TEST_F(TokenHloTest, TokenTree) {
  std::unique_ptr<HloModule> hlo_module = CreateNewVerifiedModule();
  HloComputation::Builder entry_builder(TestName());
  auto* first = entry_builder.AddInstruction(HloInstruction::CreateToken());
  auto* second = entry_builder.AddInstruction(HloInstruction::CreateToken());
  auto* third = entry_builder.AddInstruction(HloInstruction::CreateToken());
  // `first` appears twice in the operand list, so the after-all also sees a
  // duplicated operand.
  entry_builder.AddInstruction(
      HloInstruction::CreateAfterAll({first, first, second, third}));
  hlo_module->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal output, Execute(std::move(hlo_module), {}));
  EXPECT_TRUE(LiteralTestUtil::Equal(output, LiteralUtil::CreateToken()));
}
68
// Declares a token-shaped entry parameter (parameter 1) and checks that
// HloVerifier rejects the module with the matching error message.
XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
  std::unique_ptr<HloModule> hlo_module = CreateNewUnverifiedModule();
  HloComputation::Builder entry_builder(TestName());
  entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
  entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1"));
  entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(42)));
  hlo_module->AddEntryComputation(entry_builder.Build());

  HloVerifier verifier(/*layout_sensitive=*/false,
                       /*allow_mixed_precision=*/false);
  Status verify_status = verifier.Run(hlo_module.get()).status();
  ASSERT_IS_NOT_OK(verify_status);
  EXPECT_THAT(
      verify_status.error_message(),
      ::testing::HasSubstr("Entry parameter 1 is or contains a token shape"));
}
89
// Declares an entry parameter whose tuple shape nests a token and checks that
// HloVerifier rejects it just like a bare token parameter.
XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) {
  std::unique_ptr<HloModule> hlo_module = CreateNewUnverifiedModule();
  HloComputation::Builder entry_builder(TestName());
  const auto param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeTokenShape()});
  entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  hlo_module->AddEntryComputation(entry_builder.Build());

  HloVerifier verifier(/*layout_sensitive=*/false,
                       /*allow_mixed_precision=*/false);
  Status verify_status = verifier.Run(hlo_module.get()).status();
  ASSERT_IS_NOT_OK(verify_status);
  EXPECT_THAT(
      verify_status.error_message(),
      ::testing::HasSubstr("Entry parameter 0 is or contains a token shape"));
}
109
// Threads a token through a while loop. The entry computation creates the
// token with an empty after-all, and every iteration of %Body passes it through
// another after-all. The s32 counter runs from 0 until it is no longer less
// than 42; only the counter is returned and compared against the reference.
XLA_TEST_F(TokenHloTest, TokenInWhileLoop) {
  const std::string hlo_text = R"(
HloModule TokenInWhileLoop

%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
%param.1 = (s32[], token[]) parameter(0)
%get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
%constant.1 = s32[] constant(1)
%add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
%get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
%after-all = token[] after-all(token[] %get-tuple-element.2)
ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}

%Cond (param: (s32[], token[])) -> pred[] {
%param = (s32[], token[]) parameter(0)
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
%constant = s32[] constant(42)
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %TokenInWhileLoop () -> s32[] {
%zero = s32[] constant(0)
%init_token = token[] after-all()
%init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
%while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
}
)";

  DebugOptions options = GetDebugOptionsForTest();
  // Keep the module DCE pass off; otherwise it removes the token-generating
  // instructions this test is meant to exercise.
  options.add_xla_disable_hlo_passes("hlo-module-dce");

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      HloRunner::CreateModuleFromString(hlo_text, options));
  EXPECT_TRUE(RunAndCompare(std::move(hlo_module), error_spec_));
}
151
// Passes a token into the true branch of a conditional, while the false branch
// creates its own token with an empty after-all. Each branch returns an s32
// next to its token. The module is parsed and executed once per predicate
// value, and the s32 element of the result is checked.
XLA_TEST_F(TokenHloTest, TokenInConditional) {
  const std::string hlo_text = R"(
HloModule TokenInConditional

%True (param.1: token[]) -> (s32[], token[]) {
%param.1 = token[] parameter(0)
%forty_two = s32[] constant(42)
ROOT %tuple = (s32[], token[]) tuple(s32[] %forty_two, token[] %param.1)
}

%False (param.2: s32[]) -> (s32[], token[]) {
%param.2 = s32[] parameter(0)
%new_token = token[] after-all()
ROOT %tuple = (s32[], token[]) tuple(s32[] %param.2, token[] %new_token)
}

ENTRY %TokenInConditional (param.3: pred[]) -> s32[] {
%param.3 = pred[] parameter(0)
%init_token = token[] after-all()
%seven = s32[] constant(7)
%cond = (s32[], token[]) conditional(pred[] %param.3, token[] %init_token, s32[] %seven), true_computation=True, false_computation=False
ROOT %root = s32[] get-tuple-element((s32[], token[]) %cond), index=0
}
)";

  DebugOptions options = GetDebugOptionsForTest();
  // Keep the module DCE pass off; otherwise it removes the token-generating
  // instructions this test is meant to exercise.
  options.add_xla_disable_hlo_passes("hlo-module-dce");

  // One entry per branch: the predicate fed to the entry computation and the
  // s32 value that branch is expected to produce.
  struct BranchCase {
    bool predicate;
    int32_t expected;
  };
  const BranchCase kCases[] = {{true, 42}, {false, 7}};

  for (const BranchCase& branch_case : kCases) {
    // Each case runs on a freshly parsed module because Execute consumes it.
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<HloModule> hlo_module,
        HloRunner::CreateModuleFromString(hlo_text, options));
    Literal predicate = LiteralUtil::CreateR0<bool>(branch_case.predicate);
    TF_ASSERT_OK_AND_ASSIGN(Literal output,
                            Execute(std::move(hlo_module), {&predicate}));
    EXPECT_EQ(branch_case.expected, output.Get<int32_t>({}));
  }
}
201
// Orders %neg after %add through an after-all/add-dependency pair and checks
// that add-dependency forwards %p1 unchanged: (10 + 42) * -(3) == -156.
XLA_TEST_F(TokenHloTest, AddDependency) {
  const std::string hlo_text = R"(
HloModule AddDependency, is_scheduled=true

// Computes (p0 + 42) * (-p1)
// where there is a dependency from the add to the negation using a token
// with after-all and add-dependency instructions.
ENTRY %AddDependency (p0: f32[], p1: f32[]) -> f32[] {
%p0 = f32[] parameter(0)
%p1 = f32[] parameter(1)

%forty_two = f32[] constant(42.0)
%add = f32[] add(f32[] %p0, f32[] %forty_two)
%token0 = token[] after-all(f32[] %add)
%p1_after_token = f32[] add-dependency(f32[] %p1, token[] %token0)
%neg = f32[] negate(f32[] %p1_after_token)
ROOT %product = f32[] multiply(f32[] %add, f32[] %neg)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));

  Literal p0_value = LiteralUtil::CreateR0<float>(10.0);
  Literal p1_value = LiteralUtil::CreateR0<float>(3.0);
  Literal expected = LiteralUtil::CreateR0<float>(-156.0);
  Literal actual =
      ExecuteNoHloPasses(std::move(hlo_module), {&p0_value, &p1_value});
  EXPECT_EQ(expected, actual);
}
229
// Attaches a token dependency to a constant operand and checks that the
// constant's value still flows through: 10 * 42 == 420.
XLA_TEST_F(TokenHloTest, AddDependencyOfConstant) {
  const std::string hlo_text = R"(
HloModule AddDependencyOfConstant, is_scheduled=true

ENTRY %AddDependency (p0: f32[]) -> f32[] {
%p0 = f32[] parameter(0)
%forty_two = f32[] constant(42.0)
%token0 = token[] after-all(f32[] %p0)
%forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token0)
ROOT %product = f32[] multiply(f32[] %p0, f32[] %forty_two_after_token)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));

  Literal p0_value = LiteralUtil::CreateR0<float>(10.0);
  Literal expected = LiteralUtil::CreateR0<float>(420.0);
  Literal actual = ExecuteNoHloPasses(std::move(hlo_module), {&p0_value});
  EXPECT_EQ(expected, actual);
}
249
// Uses add-dependency as the root instruction and checks that the module's
// output is the (negated) data operand, not anything token-related.
XLA_TEST_F(TokenHloTest, AddDependencyAsRoot) {
  const std::string hlo_text = R"(
HloModule AddDependencyAsRoot, is_scheduled=true
ENTRY %AddDependency (p: f32[3]) -> f32[3] {
%p = f32[3] parameter(0)
%neg = f32[3] negate(f32[3] %p)
%token0 = token[] after-all()
ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));

  Literal input_value = LiteralUtil::CreateR1<float>({1.0, 3.0, 7.0});
  Literal expected = LiteralUtil::CreateR1<float>({-1.0, -3.0, -7.0});
  Literal actual = ExecuteNoHloPasses(std::move(hlo_module), {&input_value});
  EXPECT_EQ(expected, actual);
}
267
// Applies add-dependency to a tuple that itself contains a token, then reads
// elements 0 and 2 back out and subtracts them elementwise:
// {3, 3, 47} - {1, -2, 2} == {2, 5, 45}.
XLA_TEST_F(TokenHloTest, TupleShapedAddDependency) {
  const std::string hlo_text = R"(
HloModule TupleShapedAddDependency, is_scheduled=true
ENTRY %TupleShapedAddDependency (p0: f32[3], p1: f32[3]) -> f32[3] {
%p0 = f32[3] parameter(0)
%p1 = f32[3] parameter(1)
%forty_two = f32[] constant(42.0)
%token0 = token[] after-all()
%tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token0, f32[3] %p1, f32[] %forty_two)
%add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token0)
%elem0 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=0
%elem2 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=2
ROOT %diff = f32[3] subtract(f32[3] %elem0, f32[3] %elem2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));

  Literal p0_value = LiteralUtil::CreateR1<float>({3.0, 3.0, 47.0});
  Literal p1_value = LiteralUtil::CreateR1<float>({1.0, -2.0, 2.0});
  Literal expected = LiteralUtil::CreateR1<float>({2.0, 5.0, 45.0});
  Literal actual =
      ExecuteNoHloPasses(std::move(hlo_module), {&p0_value, &p1_value});
  EXPECT_EQ(expected, actual);
}
291
292 } // namespace
293 } // namespace xla
294