1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/xla_builder.h"
17
18 #include <string>
19
20 #include "tensorflow/compiler/xla/client/xla_computation.h"
21 #include "tensorflow/compiler/xla/debug_options_flags.h"
22 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
23 #include "tensorflow/compiler/xla/service/hlo_module.h"
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/compiler/xla/test.h"
27 #include "tensorflow/compiler/xla/test_helpers.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30
31 namespace xla {
32
33 namespace {
34
35 namespace op = xla::testing::opcode_matchers;
36
37 using ::testing::HasSubstr;
38
39 // TODO(b/74197823): Move the tests to service/.
40 class XlaBuilderTest : public ::testing::Test {
41 protected:
BuildHloModule(XlaBuilder * b)42 StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b) {
43 TF_ASSIGN_OR_RETURN(XlaComputation computation,
44 b->Build(/*remove_dynamic_dimensions=*/false));
45 const HloModuleProto& proto = computation.proto();
46 TF_ASSIGN_OR_RETURN(const auto& config,
47 HloModule::CreateModuleConfigFromProto(
48 proto, GetDebugOptionsFromFlags()));
49 return HloModule::CreateFromProto(proto, config);
50 }
51
52 // Overload which explicitly specifies the root instruction.
BuildHloModule(XlaBuilder * b,XlaOp root)53 StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b,
54 XlaOp root) {
55 TF_ASSIGN_OR_RETURN(XlaComputation computation,
56 b->Build(root, /*remove_dynamic_dimensions=*/false));
57 const HloModuleProto& proto = computation.proto();
58 TF_ASSIGN_OR_RETURN(const auto& config,
59 HloModule::CreateModuleConfigFromProto(
60 proto, GetDebugOptionsFromFlags()));
61 return HloModule::CreateFromProto(proto, config);
62 }
63
64 // Returns the name of the test currently being run.
TestName() const65 string TestName() const {
66 return ::testing::UnitTest::GetInstance()->current_test_info()->name();
67 }
68 };
69
// Adding two scalar constants produces a single add whose operands are both
// constants.
TEST_F(XlaBuilderTest, OnePlusTwo) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR0<float>(&builder, 1.0);
  XlaOp rhs = ConstantR0<float>(&builder, 2.0);
  Add(lhs, rhs);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Add(op::Constant(), op::Constant()));
}
77
// The overloaded unary operators on XlaOp lower to the matching HLO:
// unary minus to negate and bitwise complement to not.
TEST_F(XlaBuilderTest, UnaryOperatorsBuildExpectedHLO) {
  // Applies `build_fn` to an s32 scalar constant in a fresh builder and matches
  // the root of the resulting module against `expected`. The callable is not
  // named `op` so that it does not hide the `op` matcher namespace alias.
  auto check_unary =
      [&](std::function<XlaOp(XlaOp)> build_fn,
          ::testing::Matcher<const ::xla::HloInstruction*> expected) {
        XlaBuilder builder(TestName());
        build_fn(ConstantR0<int32>(&builder, 1));
        TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
        EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
                    expected);
      };

  check_unary([](XlaOp x) { return -x; }, op::Negate(op::Constant()));
  check_unary([](XlaOp x) { return ~x; }, op::Not(op::Constant()));
}
91
// The overloaded binary operators on XlaOp lower to the matching HLO. `>>`
// lowers to an arithmetic shift for s32 operands and to a logical shift for
// u32 operands.
TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) {
  using BinaryFn = std::function<XlaOp(XlaOp, XlaOp)>;
  using RootMatcher = ::testing::Matcher<const ::xla::HloInstruction*>;

  // Builds `build_fn(1, 2)` on s32 scalar constants in a fresh builder and
  // matches the module root against `expected`.
  auto check_signed = [&](BinaryFn build_fn, RootMatcher expected) {
    XlaBuilder builder(TestName());
    build_fn(ConstantR0<int32>(&builder, 1), ConstantR0<int32>(&builder, 2));
    TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
    EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), expected);
  };

  // Arithmetic operators.
  check_signed([](XlaOp x, XlaOp y) { return x + y; },
               op::Add(op::Constant(), op::Constant()));
  check_signed([](XlaOp x, XlaOp y) { return x - y; },
               op::Subtract(op::Constant(), op::Constant()));
  check_signed([](XlaOp x, XlaOp y) { return x * y; },
               op::Multiply(op::Constant(), op::Constant()));
  check_signed([](XlaOp x, XlaOp y) { return x / y; },
               op::Divide(op::Constant(), op::Constant()));

  // Bitwise and shift operators.
  check_signed([](XlaOp x, XlaOp y) { return x & y; },
               op::And(op::Constant(), op::Constant()));
  check_signed([](XlaOp x, XlaOp y) { return x | y; },
               op::Or(op::Constant(), op::Constant()));
  check_signed([](XlaOp x, XlaOp y) { return x ^ y; },
               op::Xor(op::Constant(), op::Constant()));
  check_signed([](XlaOp x, XlaOp y) { return x << y; },
               op::ShiftLeft(op::Constant(), op::Constant()));
  check_signed([](XlaOp x, XlaOp y) { return x >> y; },
               op::ShiftRightArithmetic(op::Constant(), op::Constant()));

  // Same as check_signed, but on u32 scalar constants.
  auto check_unsigned = [&](BinaryFn build_fn, RootMatcher expected) {
    XlaBuilder builder(TestName());
    build_fn(ConstantR0<uint32>(&builder, 1), ConstantR0<uint32>(&builder, 2));
    TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
    EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), expected);
  };
  check_unsigned([](XlaOp x, XlaOp y) { return x >> y; },
                 op::ShiftRightLogical(op::Constant(), op::Constant()));
}
137
// A three-operand And lowers to two nested binary ands. The test deliberately
// accepts either association order.
TEST_F(XlaBuilderTest, VariadicAnd) {
  XlaBuilder builder(TestName());
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  XlaOp p0 = Parameter(&builder, 0, pred_scalar, "p0");
  XlaOp p1 = Parameter(&builder, 1, pred_scalar, "p1");
  XlaOp p2 = Parameter(&builder, 2, pred_scalar, "p2");
  And(p0, p1, p2);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));

  const auto right_assoc =
      op::And(op::Parameter(0), op::And(op::Parameter(1), op::Parameter(2)));
  const auto left_assoc =
      op::And(op::And(op::Parameter(0), op::Parameter(1)), op::Parameter(2));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              ::testing::AnyOf(right_assoc, left_assoc));
}
153
// A three-operand Or lowers to two nested binary ors. The test deliberately
// accepts either association order.
TEST_F(XlaBuilderTest, VariadicOr) {
  XlaBuilder builder(TestName());
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  XlaOp p0 = Parameter(&builder, 0, pred_scalar, "p0");
  XlaOp p1 = Parameter(&builder, 1, pred_scalar, "p1");
  XlaOp p2 = Parameter(&builder, 2, pred_scalar, "p2");
  Or(p0, p1, p2);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));

  const auto right_assoc =
      op::Or(op::Parameter(0), op::Or(op::Parameter(1), op::Parameter(2)));
  const auto left_assoc =
      op::Or(op::Or(op::Parameter(0), op::Parameter(1)), op::Parameter(2));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              ::testing::AnyOf(right_assoc, left_assoc));
}
169
// `>>` on floating-point operands is rejected, and the error surfaces from
// Build().
TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR0<float>(&builder, 1);
  XlaOp rhs = ConstantR0<float>(&builder, 2);
  lhs >> rhs;

  const auto build_result = builder.Build();
  ASSERT_FALSE(build_result.ok());
  EXPECT_THAT(
      build_result.status().error_message(),
      HasSubstr("Argument to >> operator does not have an integral type"));
}
179
// Adding a scalar constant to an f32[3,5] parameter broadcasts the scalar
// before the add.
TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) {
  XlaBuilder builder(TestName());
  XlaOp x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 5}), "x");
  XlaOp one = ConstantR0<float>(&builder, 1.0);
  Add(x, one);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Add(op::Parameter(), op::Broadcast(op::Constant())));
}
188
// Adding s32[2,4] to s32[2,4,6] with broadcast_dimensions {0, 1} broadcasts
// the lower-rank operand, so the result has the larger operand's shape.
TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) {
  XlaBuilder builder(TestName());
  const Shape x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6});
  const Shape y_shape = ShapeUtil::MakeShape(S32, {2, 4});
  XlaOp x = Parameter(&builder, 0, x_shape, "x");
  XlaOp y = Parameter(&builder, 1, y_shape, "y");
  XlaOp sum = Add(x, y, /*broadcast_dimensions=*/{0, 1});

  // The builder-inferred shape already matches x before any HLO is built.
  TF_ASSERT_OK_AND_ASSIGN(auto sum_shape, builder.GetShape(sum));
  EXPECT_TRUE(ShapeUtil::Equal(sum_shape, x_shape));

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Add(op::Parameter(0), op::Broadcast(op::Parameter(1))));
}
204
// Using one operand twice yields an add whose two operands are both
// parameter 0.
TEST_F(XlaBuilderTest, XPlusX) {
  XlaBuilder builder(TestName());
  const Shape x_shape = ShapeUtil::MakeShape(S32, {1, 3, 5, 7});
  XlaOp x = Parameter(&builder, 0, x_shape, "x");
  Add(x, x);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Add(op::Parameter(0), op::Parameter(0)));
}
213
// Adding u32[2,4,6] to u32[2,4] without broadcast_dimensions fails, and the
// failure is reported as a shape-inference error.
TEST_F(XlaBuilderTest, ShapeInferenceError) {
  XlaBuilder builder(TestName());
  XlaOp x = Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x");
  XlaOp y = Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {2, 4}), "y");
  Add(x, y);

  const auto module_or = BuildHloModule(&builder);
  ASSERT_FALSE(module_or.ok());
  EXPECT_THAT(module_or.status().error_message(),
              HasSubstr("shape inference"));
}
223
// Registering the same parameter number twice in one builder is an error
// reported when the module is built.
TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) {
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "x");
  // Deliberately reuses parameter number 0.
  auto y = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "y");
  Add(x, y);
  auto statusor = BuildHloModule(&b);
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("parameter 0 already registered"));
}
237
// One callee computation can be called twice, once on parameters and once on
// constants, and the two results can then be added.
TEST_F(XlaBuilderTest, Call) {
  // Callee: (p0, p1) -> p0 + p1.
  XlaBuilder callee_builder("the_only_to_apply");
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  XlaOp p0 = Parameter(&callee_builder, 0, f32_scalar, "p0");
  XlaOp p1 = Parameter(&callee_builder, 1, f32_scalar, "p1");
  Add(p0, p1);
  TF_ASSERT_OK_AND_ASSIGN(auto callee, callee_builder.Build());

  XlaBuilder builder(TestName());
  XlaOp x = Parameter(&builder, 0, f32_scalar, "x");
  XlaOp y = Parameter(&builder, 1, f32_scalar, "y");
  XlaOp one = ConstantR0<float>(&builder, 1);
  XlaOp two = ConstantR0<float>(&builder, 2);
  XlaOp call_on_params = Call(&builder, callee, {x, y});
  XlaOp call_on_constants = Call(&builder, callee, {one, two});
  Add(call_on_params, call_on_constants);

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Add(op::Call(op::Parameter(), op::Parameter()),
                      op::Call(op::Constant(), op::Constant())));
}
255
// A size-1 ("degenerate") dimension in one operand is handled by reshaping it
// away and then broadcasting back to the full shape.
TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) {
  XlaBuilder builder(TestName());
  XlaOp x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x");
  XlaOp y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y");
  Add(x, y);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));

  // Expected graph:
  //
  //   x: f32[1,2,3]          y: f32[1,2,1]
  //        |                       |
  //        |               reshape: f32[1,2]
  //        |                       |
  //        |              broadcast: f32[1,2,3]
  //         \                     /
  //                    add
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Add(op::Parameter(0),
                      op::Broadcast(op::Reshape(op::Parameter(1)))));
}
276
// A binary op that needs both an in-dim broadcast (rank 2 -> rank 3 via
// broadcast_dimensions) and a degenerate broadcast (size-1 dimension).
TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) {
  XlaBuilder builder(TestName());
  XlaOp x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x");
  XlaOp y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y");
  Add(x, y, /*broadcast_dimensions=*/{0, 1});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));

  // The in-dim broadcast is done first; the degenerate broadcast is then
  // converted into a reshape followed by a broadcast.
  //
  // Expected graph:
  //
  //   x: f32[2,3]             y: f32[2,1,4]
  //        |                        |
  //   broadcast: f32[2,3,4]   reshape: f32[2,4]
  //        |                        |
  //        |               broadcast: f32[2,3,4]
  //         \                      /
  //                    add
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Add(op::Broadcast(op::Parameter(0)),
                      op::Broadcast(op::Reshape(op::Parameter(1)))));
}
301
// BroadcastInDim maps operand dims {0, 1} of f32[2,3] onto output dims {0, 2}
// of f32[2,4,3]. The root must be a broadcast with that output shape.
TEST_F(XlaBuilderTest, BroadcastInDim) {
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x");
  BroadcastInDim(x, {2, 4, 3},
                 /*broadcast_dimensions=*/{0, 2});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Broadcast());
  // The opcode alone would also match a broadcast to the wrong shape; pin the
  // requested output shape too.
  EXPECT_TRUE(
      ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {2, 4, 3})))
      << root->shape();
}
311
// When a size-1 operand dimension maps onto a larger output dimension,
// BroadcastInDim lowers to broadcast -> reshape -> broadcast.
TEST_F(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) {
  XlaBuilder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2, 1, 4});
  XlaOp x = Parameter(&builder, 0, operand_shape, "x");
  BroadcastInDim(x, /*out_dim_size=*/{2, 3, 4},
                 /*broadcast_dimensions=*/{0, 1, 2});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const auto expected = op::Broadcast(op::Reshape(op::Broadcast()));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), expected);
}
321
// Using an op created by one builder inside another builder makes Build()
// fail with an error that names both builders.
TEST_F(XlaBuilderTest, OperandFromWrongBuilder) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  XlaBuilder foreign_builder("b1");
  XlaOp foreign_param = Parameter(&foreign_builder, 0, f32_scalar, "p0");

  XlaBuilder builder("main");
  XlaOp param = Parameter(&builder, 0, f32_scalar, "p");
  Add(param, foreign_param);

  const auto build_result = builder.Build();
  ASSERT_FALSE(build_result.ok());
  EXPECT_THAT(
      build_result.status().error_message(),
      HasSubstr(
          "built by builder 'b1', but is trying to use it in builder 'main'"));
}
335
// A reshape without an explicit dimension order lowers to a single reshape of
// the parameter, with no transpose.
TEST_F(XlaBuilderTest, ReshapeDefaultOrder) {
  XlaBuilder builder(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5, 7});
  XlaOp x = Parameter(&builder, 0, input_shape, "x");
  Reshape(x, /*new_sizes=*/{6, 35});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Reshape(op::Parameter()));
}
344
// A reshape with a non-default dimension order is lowered as a transpose
// followed by a reshape.
TEST_F(XlaBuilderTest, ReshapeHasTranspose) {
  XlaBuilder builder(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5, 7});
  XlaOp x = Parameter(&builder, 0, input_shape, "x");
  Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Reshape(op::Transpose(op::Parameter())));
}
353
// Transpose lowers to a single transpose of the parameter.
TEST_F(XlaBuilderTest, Transpose) {
  XlaBuilder builder(TestName());
  XlaOp x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
  Transpose(x, /*permutation=*/{1, 0});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Transpose(op::Parameter()));
}
362
// AllToAll on f32[4,16], splitting dim 1 into 2 pieces and concatenating on
// dim 0, yields f32[8,8] and is built from an all-to-all instruction.
TEST_F(XlaBuilderTest, AllToAll) {
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
  AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0,
           /*split_count=*/2);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  auto root = module->entry_computation()->root_instruction();

  // AllToAll is decomposed into slices -> all-to-all -> gte -> concat.
  // Use fatal assertions on the structure walked below. With EXPECT a
  // regression would carry on into operand(0)->operand(0) and could crash the
  // test binary instead of reporting a clean failure.
  ASSERT_EQ(root->opcode(), HloOpcode::kConcatenate);
  ASSERT_GT(root->operand_count(), 0);
  const auto* concat_input = root->operand(0);
  ASSERT_GT(concat_input->operand_count(), 0);
  EXPECT_EQ(concat_input->operand(0)->opcode(), HloOpcode::kAllToAll);
  EXPECT_TRUE(
      ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8})))
      << root->shape();
}
377
// CollectivePermute with explicit source/target pairs produces a
// collective-permute root.
TEST_F(XlaBuilderTest, CollectivePermute) {
  XlaBuilder builder(TestName());
  XlaOp x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
  CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const auto* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute);
}
386
// GetDimensionSize produces a get-dimension-size root.
TEST_F(XlaBuilderTest, GetDimensionSize) {
  XlaBuilder builder(TestName());
  XlaOp x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
  GetDimensionSize(x, /*dimension=*/1);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const auto* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kGetDimensionSize);
}
395
// An error recorded via ReportError is returned from Build().
TEST_F(XlaBuilderTest, ReportError) {
  XlaBuilder builder(TestName());
  XlaOp x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
  XlaOp failed = builder.ReportError(InvalidArgument("a test error"));
  Add(failed, x);

  const auto build_result = builder.Build();
  ASSERT_FALSE(build_result.ok());
  EXPECT_THAT(build_result.status().error_message(),
              HasSubstr("a test error"));
}
404
// ReportErrorOrReturn on an OK StatusOr passes the contained op through
// unchanged.
TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesNonErrors) {
  XlaBuilder builder(TestName());
  // Not named `op`, so it does not shadow the matcher namespace alias.
  StatusOr<XlaOp> maybe_constant(ConstantR0<float>(&builder, 1.0));
  XlaOp lhs = builder.ReportErrorOrReturn(maybe_constant);
  XlaOp rhs = ConstantR0<float>(&builder, 2.0);
  Add(lhs, rhs);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Add(op::Constant(), op::Constant()));
}
413
// ReportErrorOrReturn on an error StatusOr records the error, and Build()
// returns it.
TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) {
  XlaBuilder builder(TestName());
  // Not named `op`, so it does not shadow the matcher namespace alias.
  StatusOr<XlaOp> maybe_op(InvalidArgument("a test error"));
  XlaOp lhs = builder.ReportErrorOrReturn(maybe_op);
  XlaOp rhs = ConstantR0<float>(&builder, 2.0);
  Add(lhs, rhs);

  const auto build_result = builder.Build();
  ASSERT_FALSE(build_result.ok());
  EXPECT_THAT(build_result.status().error_message(),
              HasSubstr("a test error"));
}
422
// Passing an explicit root to Build overrides the default root (the last
// instruction added).
TEST_F(XlaBuilderTest, BuildWithSpecificRoot) {
  XlaBuilder builder(TestName());
  XlaOp constant = ConstantR0<float>(&builder, 1.0);
  XlaOp two = ConstantR0<float>(&builder, 2.0);
  Add(constant, two);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          BuildHloModule(&builder, /*root=*/constant));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Constant());
}
431
// Choosing a parameter as the root still keeps every entry parameter, along
// with the other instructions that were built.
TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) {
  XlaBuilder builder(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {42, 123});
  XlaOp x = Parameter(&builder, 0, shape, "x");
  XlaOp y = Parameter(&builder, 1, shape, "y");
  XlaOp z = Parameter(&builder, 2, shape, "z");
  Add(x, Sub(y, z));
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          BuildHloModule(&builder, /*root=*/x));

  const auto* entry = hlo_module->entry_computation();
  EXPECT_THAT(entry->root_instruction(), op::Parameter());
  EXPECT_EQ(entry->num_parameters(), 3);
  // 3 parameters + subtract + add.
  EXPECT_EQ(entry->instruction_count(), 5);
}
447
// Build rejects a root that was created by a different builder.
TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) {
  const Shape shape = ShapeUtil::MakeShape(F32, {42, 123});
  XlaBuilder builder(TestName());
  XlaBuilder other_builder(TestName());

  Parameter(&builder, 0, shape, "param");
  XlaOp other_param = Parameter(&other_builder, 0, shape, "other_param");

  const Status status = builder.Build(other_param).status();
  ASSERT_IS_NOT_OK(status);
  EXPECT_THAT(status.error_message(),
              HasSubstr("root operation is not in this computation"));
}
462
// Building the same computation (including a called sub-computation) twice
// must produce identical serialized protos.
TEST_F(XlaBuilderTest, ProtoMatches) {
  constexpr int kNumBuilds = 2;
  std::vector<XlaComputation> computations;
  computations.reserve(kNumBuilds);
  for (int i = 0; i < kNumBuilds; ++i) {
    XlaBuilder b_call("the_only_to_apply");
    auto p0 = Parameter(&b_call, 0, ShapeUtil::MakeShape(F32, {}), "p0");
    auto p1 = Parameter(&b_call, 1, ShapeUtil::MakeShape(F32, {}), "p1");
    Add(p0, Add(p1, p0));
    TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build());
    XlaBuilder b(TestName());
    auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
    auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
    auto one = ConstantR0<float>(&b, 1);
    auto two = ConstantR0<float>(&b, 2);
    Add(Call(&b, call, {x, y}), Call(&b, call, {one, two}));
    // Report a build failure as a test failure. ValueOrDie() would instead
    // abort the whole test binary.
    TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());
    computations.push_back(std::move(computation));
  }
  auto c0_string = computations[0].proto().SerializeAsString();
  auto c1_string = computations[1].proto().SerializeAsString();
  EXPECT_EQ(c0_string, c1_string);
}
483
// Dimension 0 of tuple element {1} of parameter 0 is bound to the scalar
// parameter 1. The built module keeps that dimension marked dynamic.
TEST_F(XlaBuilderTest, DynamicParameter) {
  XlaBuilder builder(TestName());
  const Shape static_vector = ShapeUtil::MakeShape(F32, {5});
  const Shape dynamic_vector = ShapeUtil::MakeShape(F32, {6}, {true});
  const Shape tuple_param_shape =
      ShapeUtil::MakeTupleShape({static_vector, dynamic_vector});
  XlaOp p0 = Parameter(&builder, 0, tuple_param_shape, "p0");
  Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "p1");
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/1,
                                         /*dynamic_size_param_index=*/{},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{1},
                                         /*target_dim_num=*/0));
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          BuildHloModule(&builder, /*root=*/p0));

  const auto* param0 = hlo_module->entry_computation()->parameter_instruction(0);
  const Shape& element_shape = param0->shape().tuple_shapes(1);
  EXPECT_TRUE(element_shape.is_dynamic_dimension(0));
}
502
// A unary op (negate) on a dynamic f32[<=5] keeps dimension 0 dynamic.
TEST_F(XlaBuilderTest, DynamicUnary) {
  XlaBuilder builder(TestName());
  const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})});
  XlaOp p0 = Parameter(&builder, 0, tuple_param_shape, "p0");
  // The u32 scalar at element {1} is bound as the size of element {0}, dim 0.
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{1},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  XlaOp operand = GetTupleElement(p0, 0);
  Neg(operand);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));

  const auto* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_TRUE(root->shape().is_dynamic_dimension(0));
}
520
// Adding two f32[<=5] operands, both bound to the same size scalar, keeps
// dimension 0 dynamic.
TEST_F(XlaBuilderTest, DynamicBinary) {
  XlaBuilder builder(TestName());
  const Shape dynamic_vector = ShapeUtil::MakeShape(F32, {5}, {true});
  const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {dynamic_vector, dynamic_vector, ShapeUtil::MakeShape(U32, {})});
  XlaOp p0 = Parameter(&builder, 0, tuple_param_shape, "p0");
  // Element {2} (u32 scalar) is the size of dim 0 of elements {0} and {1}.
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{2},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{2},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{1},
                                         /*target_dim_num=*/0));
  XlaOp lhs = GetTupleElement(p0, 0);
  XlaOp rhs = GetTupleElement(p0, 1);
  Add(lhs, rhs);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));

  const auto* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_TRUE(root->shape().is_dynamic_dimension(0));
}
545
// Adding f32[<=5,4] and f32[<=5] with broadcast_dimensions {0} gives a result
// whose dimension 0 stays dynamic and whose dimension 1 is static.
TEST_F(XlaBuilderTest, DynamicBinaryHasBroadcast) {
  XlaBuilder builder(TestName());
  const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}),
       ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})});
  XlaOp p0 = Parameter(&builder, 0, tuple_param_shape, "p0");
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{2},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{2},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{1},
                                         /*target_dim_num=*/0));
  XlaOp matrix = GetTupleElement(p0, 0);
  XlaOp vector = GetTupleElement(p0, 1);
  Add(matrix, vector, /*broadcast_dimensions=*/{0});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));

  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false}))
      << result_shape;
}
571
// BroadcastInDim of f32[<=5,4] into f32[3,5,4] (operand dims -> output dims
// {1, 2}) moves the dynamic dimension to output dimension 1.
TEST_F(XlaBuilderTest, DynamicBroadcast) {
  XlaBuilder builder(TestName());
  const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}),
       ShapeUtil::MakeShape(U32, {})});
  XlaOp p0 = Parameter(&builder, 0, tuple_param_shape, "p0");
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{1},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  XlaOp operand = GetTupleElement(p0, 0);
  BroadcastInDim(operand, /*out_dim_size=*/{3, 5, 4},
                 /*broadcast_dimensions=*/{1, 2});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));

  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(
      ContainersEqual(result_shape.dynamic_dimensions(), {false, true, false}))
      << result_shape;
}
593
// Adding f32[<=10] to f32[1,15] with broadcast_dimensions {0} gives
// f32[<=10,15]. The dynamic dimension survives the degenerate broadcast.
TEST_F(XlaBuilderTest, DynamicBinaryHasDegenerateBroadcast) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {10}, {true}),
       ShapeUtil::MakeShape(F32, {1, 15}), ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  // The dynamic size must come from the u32 scalar at tuple index {2}, as in
  // every other dynamic test here. Index {1} is the f32[1,15] operand.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  auto gte0 = GetTupleElement(p0, 0);
  auto gte1 = GetTupleElement(p0, 1);
  Add(gte0, gte1, /*broadcast_dimensions=*/{0});  // f32[<=10, 15]
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false}))
      << result_shape;
}
614
// Select with a dynamic predicate and static on-true/on-false operands yields
// a result whose dimension 0 is dynamic.
TEST_F(XlaBuilderTest, DynamicSelectOnlyPredDynamic) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(PRED, {10}, {true}),
       ShapeUtil::MakeShape(F32, {10}), ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  // The dynamic size must come from the u32 scalar at tuple index {2}. Index
  // {1} is the f32[10] data operand.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  auto gte0 = GetTupleElement(p0, 0);
  auto gte1 = GetTupleElement(p0, 1);

  Select(gte0, gte1, gte1);

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true}))
      << result_shape;
}
637
// Zero padding of f32[<=5,4] keeps dimension 0 dynamic and dimension 1
// static.
TEST_F(XlaBuilderTest, DynamicPad) {
  XlaBuilder builder(TestName());
  const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}),
       ShapeUtil::MakeShape(U32, {})});
  XlaOp p0 = Parameter(&builder, 0, tuple_param_shape, "p0");
  XlaOp pad_value = ConstantR0<float>(&builder, -1);
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{1},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  XlaOp operand = GetTupleElement(p0, 0);

  // No low, high or interior padding on either of the two dimensions.
  PaddingConfig padding_config;
  for (int dim = 0; dim < 2; ++dim) {
    auto* dim_config = padding_config.add_dimensions();
    dim_config->set_edge_padding_low(0);
    dim_config->set_edge_padding_high(0);
    dim_config->set_interior_padding(0);
  }
  Pad(operand, pad_value, padding_config);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));

  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false}))
      << result_shape;
}
665
// Convolution with a dynamic input batch dimension and a dynamic kernel
// input-feature dimension. Only the batch dimension is dynamic in the output.
TEST_F(XlaBuilderTest, DynamicConvolution) {
  XlaBuilder builder(TestName());
  const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {1, 2, 2, 128}, {true, false, false, false}),
       ShapeUtil::MakeShape(F32, {2, 2, 128, 8}, {false, false, true, false}),
       ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
  XlaOp p0 = Parameter(&builder, 0, tuple_param_shape, "p0");
  // Element {2} sizes input dim 0; element {3} sizes kernel dim 2.
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{2},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{3},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{1},
                                         /*target_dim_num=*/2));
  XlaOp input = GetTupleElement(p0, 0);
  XlaOp filter = GetTupleElement(p0, 1);

  // NHWC input/output, HWIO kernel.
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.set_input_feature_dimension(3);
  dnums.set_output_batch_dimension(0);
  dnums.add_output_spatial_dimensions(1);
  dnums.add_output_spatial_dimensions(2);
  dnums.set_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(3);
  ConvWithGeneralDimensions(input, filter, /*window_strides=*/{1, 1},
                            Padding::kValid, dnums,
                            /*feature_group_count=*/1);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));

  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(),
                              {true, false, false, false}))
      << result_shape;
}
707
// Batched dot of f32[<=2,<=3,4] x f32[<=2,4,5] contracting lhs dim 2 with rhs
// dim 1. The batch dimension and the lhs free dimension stay dynamic in the
// f32[2,3,5] result.
TEST_F(XlaBuilderTest, DynamicDot) {
  XlaBuilder builder(TestName());
  const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 3, 4}, {true, true, false}),
       ShapeUtil::MakeShape(F32, {2, 4, 5}, {true, false, false}),
       ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
  XlaOp p0 = Parameter(&builder, 0, tuple_param_shape, "p0");
  // Element {2} sizes the shared batch dimension of both operands.
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{2},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{2},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{1},
                                         /*target_dim_num=*/0));
  // Element {3} sizes the lhs free dimension.
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{3},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/1));

  XlaOp lhs = GetTupleElement(p0, 0);
  XlaOp rhs = GetTupleElement(p0, 1);
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);
  DotGeneral(lhs, rhs, dnums);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));

  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(
      ContainersEqual(result_shape.dynamic_dimensions(), {true, true, false}))
      << result_shape;
}
746
// Reducing f32[5,<=4,3] over its static dimension 0 must leave the surviving
// dynamic dimension dynamic in the f32[<=4,3] result.
TEST_F(XlaBuilderTest, DynamicReduce) {
  XlaBuilder builder(TestName());
  const Shape operand_shape =
      ShapeUtil::MakeShape(F32, {5, 4, 3}, {false, true, false});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {operand_shape, ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  XlaOp zero = ConstantR0<float>(&builder, 0);
  // Tuple element 1 carries the runtime size of operand dimension 1.
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{1},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/1));
  XlaOp operand = GetTupleElement(param, 0);

  // Scalar f32 addition used as the reduction computation.
  XlaBuilder add_builder(TestName());
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  Add(Parameter(&add_builder, 0, scalar, "x"),
      Parameter(&add_builder, 1, scalar, "y"));
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation add_fn, add_builder.Build());

  Reduce(operand, zero, add_fn, /*dimensions_to_reduce=*/{0});

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const Shape& root_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(root_shape.dynamic_dimensions(), {true, false}))
      << root_shape;
}
771
// A reduce-window over f32[<=2,4,8] with window {1,2,4} and unit strides must
// keep dimension 0 dynamic and the windowed dimensions static.
TEST_F(XlaBuilderTest, DynamicReduceWindow) {
  XlaBuilder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}),
       ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  XlaOp zero = ConstantR0<float>(&builder, 0.f);
  // Tuple element 1 carries the runtime size of operand dimension 0.
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{1},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  XlaOp operand = GetTupleElement(param, 0);

  // Scalar f32 addition used as the window reduction.
  XlaBuilder add_builder(TestName());
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  Add(Parameter(&add_builder, 0, scalar, "x"),
      Parameter(&add_builder, 1, scalar, "y"));
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation add_fn, add_builder.Build());

  ReduceWindow(operand, zero, add_fn, /*window_dimensions=*/{1, 2, 4},
               /*window_strides=*/{1, 1, 1}, Padding::kValid);

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const Shape& root_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(root_shape.dynamic_dimensions(),
                              {true, false, false}))
      << root_shape;
}
798
// Select-and-scatter where both the operand f32[<=2,4,8] and the source
// f32[<=2,2,2] are dynamic in dimension 0: the result must be dynamic in
// dimension 0 only.
TEST_F(XlaBuilderTest, DynamicSelectAndScatter) {
  XlaBuilder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}),
       ShapeUtil::MakeShape(F32, {2, 2, 2}, {true, false, false}),
       ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  XlaOp zero = ConstantR0<float>(&builder, 0.f);

  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  // Scatter computation: scalar f32 addition.
  XlaBuilder add_builder(TestName());
  Add(Parameter(&add_builder, 0, scalar, "x"),
      Parameter(&add_builder, 1, scalar, "y"));
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation add_fn, add_builder.Build());
  // Select computation: scalar f32 greater-or-equal.
  XlaBuilder ge_builder(TestName());
  Ge(Parameter(&ge_builder, 0, scalar, "x"),
     Parameter(&ge_builder, 1, scalar, "y"));
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation ge_fn, ge_builder.Build());

  // The operand (element 0) and the source (element 1) both take their
  // dimension-0 size from tuple element 2.
  for (int target : {0, 1}) {
    ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                           /*dynamic_size_param_index=*/{2},
                                           /*target_param_num=*/0,
                                           /*target_param_index=*/{target},
                                           /*target_dim_num=*/0));
  }
  XlaOp operand = GetTupleElement(param, 0);
  XlaOp source = GetTupleElement(param, 1);
  SelectAndScatter(operand, ge_fn, /*window_dimensions=*/{1, 2, 4},
                   /*window_strides=*/{1, 2, 4}, Padding::kValid, source, zero,
                   add_fn);

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const Shape& root_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(root_shape.dynamic_dimensions(),
                              {true, false, false}))
      << root_shape;
}
837
// Reshaping f32[2,3,<=4,<=5,6] to {6,4,1,5,2,3} must carry the two dynamic
// dimensions over to output dimensions 1 and 3.
TEST_F(XlaBuilderTest, DynamicReshape) {
  XlaBuilder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(
      F32, {2, 3, 4, 5, 6}, {false, false, true, true, false});
  const Shape size_shape = ShapeUtil::MakeShape(U32, {});
  const Shape param_shape =
      ShapeUtil::MakeTupleShape({operand_shape, size_shape, size_shape});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");

  // Binds operand dimension `dim` to the size stored at tuple element `size`.
  auto bind_operand_dim = [&builder](int size, int dim) {
    return builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                     /*dynamic_size_param_index=*/{size},
                                     /*target_param_num=*/0,
                                     /*target_param_index=*/{0},
                                     /*target_dim_num=*/dim);
  };
  ASSERT_IS_OK(bind_operand_dim(/*size=*/1, /*dim=*/2));
  ASSERT_IS_OK(bind_operand_dim(/*size=*/2, /*dim=*/3));

  XlaOp operand = GetTupleElement(param, 0);  // f32[2, 3, <=4, <=5, 6]
  Reshape(operand, /*new_sizes=*/{6, 4, 1, 5, 2, 3});

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const Shape& root_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(root_shape.is_dynamic_dimension(1));
  EXPECT_TRUE(root_shape.is_dynamic_dimension(3));
  EXPECT_TRUE(ContainersEqual(root_shape.dynamic_dimensions(),
                              {false, true, false, true, false, false}))
      << root_shape;
}
866
// Selecting between two f32[4,<=5,6] operands that are dynamic in the same
// dimension must produce a result that is dynamic in that dimension only.
TEST_F(XlaBuilderTest, DynamicSelect) {
  XlaBuilder builder(TestName());
  const Shape operand_shape =
      ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false});
  const Shape size_shape = ShapeUtil::MakeShape(U32, {});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {operand_shape, operand_shape, size_shape, size_shape});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  XlaOp predicate =
      Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "pred");

  // Dimension 1 of tuple element `target` is sized by tuple element `size`.
  auto bind_dim1 = [&builder](int size, int target) {
    return builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                     /*dynamic_size_param_index=*/{size},
                                     /*target_param_num=*/0,
                                     /*target_param_index=*/{target},
                                     /*target_dim_num=*/1);
  };
  ASSERT_IS_OK(bind_dim1(/*size=*/2, /*target=*/0));
  ASSERT_IS_OK(bind_dim1(/*size=*/3, /*target=*/1));

  XlaOp on_true = GetTupleElement(param, 0);
  XlaOp on_false = GetTupleElement(param, 1);
  Select(predicate, on_true, on_false);

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const Shape& root_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(root_shape.is_dynamic_dimension(1));
  EXPECT_FALSE(root_shape.is_dynamic_dimension(2));
  EXPECT_TRUE(ContainersEqual(root_shape.dynamic_dimensions(),
                              {false, true, false}))
      << root_shape;
}
897
// Select operands that are dynamic in different dimensions (f32[4,<=5,6] vs
// f32[4,5,<=6]) must be rejected when the module is built.
TEST_F(XlaBuilderTest, DynamicSelectNotCompatible) {
  XlaBuilder builder(TestName());
  const Shape size_shape = ShapeUtil::MakeShape(U32, {});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false}),
       ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, false, true}), size_shape,
       size_shape});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  XlaOp predicate =
      Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "pred");

  // Binds dimension `dim` of tuple element `target` to tuple element `size`.
  auto bind = [&builder](int size, int target, int dim) {
    return builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                     /*dynamic_size_param_index=*/{size},
                                     /*target_param_num=*/0,
                                     /*target_param_index=*/{target},
                                     /*target_dim_num=*/dim);
  };
  ASSERT_IS_OK(bind(/*size=*/2, /*target=*/0, /*dim=*/1));
  ASSERT_IS_OK(bind(/*size=*/3, /*target=*/1, /*dim=*/2));

  XlaOp on_true = GetTupleElement(param, 0);   // f32[4,<=5,6]
  XlaOp on_false = GetTupleElement(param, 1);  // f32[4,5,<=6]
  Select(predicate, on_true, on_false);

  Status build_status = BuildHloModule(&builder).status();
  ASSERT_IS_NOT_OK(build_status);
  EXPECT_THAT(build_status.error_message(),
              HasSubstr("Operands to select must be the same shape; "
                        "got f32[4,<=5,6] and f32[4,5,<=6]"));
}
925
// Transposing f32[<=3,5] with permutation {1,0} must move the dynamic
// dimension from position 0 to position 1.
TEST_F(XlaBuilderTest, DynamicTranspose) {
  XlaBuilder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {3, 5}, {true, false});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {operand_shape, ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  // Tuple element 1 carries the runtime size of operand dimension 0.
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{1},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  XlaOp operand = GetTupleElement(param, 0);
  Transpose(operand, /*permutation=*/{1, 0});

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const Shape& root_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(root_shape.dynamic_dimensions(), {false, true}))
      << root_shape;
}
945
// AfterAll accepts only token operands; passing an f32 constant alongside a
// token must make Build() fail with a descriptive error.
TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) {
  XlaBuilder builder(TestName());
  XlaOp token = CreateToken(&builder);
  XlaOp not_a_token = ConstantR0<float>(&builder, 1.0);
  AfterAll(&builder, {token, not_a_token});

  Status build_status = builder.Build().status();
  ASSERT_IS_NOT_OK(build_status);
  EXPECT_THAT(build_status.error_message(),
              HasSubstr("All operands to AfterAll must be tokens"));
}
954
// Aliases registered with SetUpAlias must appear in the built module's
// HloInputOutputAliasConfig: output {1} <-> parameter 0 and
// output {0} <-> parameter 1.
TEST_F(XlaBuilderTest, CheckInputOutputAlias) {
  XlaBuilder builder(TestName());
  const Shape f32_8x4 = ShapeUtil::MakeShape(F32, {8, 4});
  XlaOp lhs = Parameter(&builder, 0, f32_8x4, "p0");
  XlaOp rhs = Parameter(&builder, 1, f32_8x4, "p1");
  XlaOp sum = Add(lhs, rhs);
  XlaOp difference = Sub(lhs, rhs);
  XlaOp root = Tuple(&builder, {sum, difference});

  builder.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0,
                     /*param_index=*/{});
  builder.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
                     /*param_index=*/{});

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder, root));

  const HloInputOutputAliasConfig& alias_config =
      hlo_module->input_output_alias_config();
  EXPECT_TRUE(alias_config.ParameterHasAlias(0, {}));
  EXPECT_TRUE(alias_config.ParameterHasAlias(1, {}));

  const auto output_for_p0 = alias_config.GetAliasedOutput(0, {});
  ASSERT_TRUE(output_for_p0.has_value());
  EXPECT_EQ(*output_for_p0, ShapeIndex({1}));

  const auto output_for_p1 = alias_config.GetAliasedOutput(1, {});
  ASSERT_TRUE(output_for_p1.has_value());
  EXPECT_EQ(*output_for_p1, ShapeIndex({0}));
}
980
981 } // namespace
982 } // namespace xla
983