1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"

#include <utility>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test_benchmark.h"
35
36 namespace op = xla::testing::opcode_matchers;
37
38 namespace xla {
39 namespace {
40
41 class DynamicDimensionInferenceTest : public HloTestBase {
42 protected:
DynamicDimensionInferenceTest()43 DynamicDimensionInferenceTest() : HloTestBase() {
44 module_ = CreateNewVerifiedModule();
45 }
46
RunInference(DynamicDimensionInference::CustomCallInferenceHandler handler=nullptr)47 Status RunInference(
48 DynamicDimensionInference::CustomCallInferenceHandler handler = nullptr) {
49 TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference,
50 DynamicDimensionInference::Run(module_.get(), handler));
51
52 inference_ = absl::make_unique<DynamicDimensionInference>(inference);
53 return Status::OK();
54 }
55
GetAdd()56 HloComputation* GetAdd() {
57 auto embedded_builder = HloComputation::Builder("add");
58 auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
59 0, ShapeUtil::MakeShape(F32, {}), "lhs"));
60 auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
61 1, ShapeUtil::MakeShape(F32, {}), "rhs"));
62 embedded_builder.AddInstruction(
63 HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
64 return module_->AddEmbeddedComputation(embedded_builder.Build());
65 }
66
GetAddTuple()67 HloComputation* GetAddTuple() {
68 auto embedded_builder = HloComputation::Builder("add");
69 auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
70 0, ShapeUtil::MakeShape(F32, {}), "lhs"));
71 auto lhs_1 =
72 embedded_builder.AddInstruction(HloInstruction::CreateParameter(
73 1, ShapeUtil::MakeShape(F32, {}), "lhs.1"));
74 auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
75 2, ShapeUtil::MakeShape(F32, {}), "rhs"));
76 auto rhs_1 =
77 embedded_builder.AddInstruction(HloInstruction::CreateParameter(
78 3, ShapeUtil::MakeShape(F32, {}), "rhs.1"));
79 auto add = embedded_builder.AddInstruction(
80 HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
81 auto add_1 = embedded_builder.AddInstruction(HloInstruction::CreateBinary(
82 lhs->shape(), HloOpcode::kAdd, lhs_1, rhs_1));
83 embedded_builder.AddInstruction(HloInstruction::CreateTuple({add, add_1}));
84 return module_->AddEmbeddedComputation(embedded_builder.Build());
85 }
86
GetGe()87 HloComputation* GetGe() {
88 auto embedded_builder = HloComputation::Builder("ge");
89 auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
90 0, ShapeUtil::MakeShape(F32, {}), "lhs"));
91 auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
92 1, ShapeUtil::MakeShape(F32, {}), "rhs"));
93 embedded_builder.AddInstruction(HloInstruction::CreateCompare(
94 ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kGe));
95 return module_->AddEmbeddedComputation(embedded_builder.Build());
96 }
97
98 std::unique_ptr<HloModule> module_;
99 std::unique_ptr<DynamicDimensionInference> inference_;
100 const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {});
101 };
102
TEST_F(DynamicDimensionInferenceTest, ParamTest) {
  // A dynamic-size binding on an entry parameter is reported by inference for
  // exactly the bound dimension and nothing else.
  HloComputation::Builder builder(TestName());
  const auto operand_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});

  auto* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "param"));
  auto* dim_size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param"));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  // Parameter 1 holds the runtime size of dimension 1 of parameter 0.
  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      SizeParam{1, {}}, DynamicDim{0, {}, 1}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(operand, {}, 1), dim_size);
  // Neither the unbound dimension nor the size scalar itself is dynamic.
  EXPECT_EQ(inference_->GetDynamicSize(operand, {}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(dim_size, {}, 0), nullptr);
}
125
TEST_F(DynamicDimensionInferenceTest, ParamTestTuple) {
  // The dynamic size can live in the same tuple parameter as the data it
  // describes.
  HloComputation::Builder builder(TestName());
  const auto data_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const auto tuple_shape =
      ShapeUtil::MakeTupleShape({data_shape, scalar_shape_});

  auto* tuple_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));

  module_->AddEntryComputation(builder.Build());

  // Tuple element {1} of parameter 0 is the size of dimension 1 of element
  // {0}.
  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      SizeParam{0, {1}}, DynamicDim{0, {0}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  // The size is exposed as a get-tuple-element of the parameter.
  EXPECT_THAT(inference_->GetDynamicSize(tuple_param, {0}, 1),
              op::GetTupleElement(tuple_param, 1));
  EXPECT_EQ(inference_->GetDynamicSize(tuple_param, {0}, 0), nullptr);
}
146
TEST_F(DynamicDimensionInferenceTest, GetTupleElement) {
  // Pulling the data out with a get-tuple-element keeps its dynamic size; the
  // queried shape index just loses its leading tuple index ({0} -> {}).
  HloComputation::Builder builder(TestName());
  const auto data_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});

  auto* tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeTupleShape({data_shape, scalar_shape_}), "param"));
  auto* data_elem = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));

  module_->AddEntryComputation(builder.Build());

  // Tuple element {1} holds the size of dimension 1 of tuple element {0}.
  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      SizeParam{0, {1}}, DynamicDim{0, {0}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  const auto expected_size = op::GetTupleElement(tuple_param, 1);
  EXPECT_THAT(inference_->GetDynamicSize(tuple_param, {0}, 1), expected_size);
  EXPECT_THAT(inference_->GetDynamicSize(data_elem, {}, 1), expected_size);
  EXPECT_EQ(inference_->GetDynamicSize(tuple_param, {0}, 0), nullptr);
}
175
TEST_F(DynamicDimensionInferenceTest, ElementwiseTest) {
  // An elementwise op (negate here) forwards its operand's dynamic size
  // unchanged.
  HloComputation::Builder builder(TestName());
  const auto operand_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});

  auto* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "data_param"));
  auto* dim_size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  auto* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(operand_shape, HloOpcode::kNegate, operand));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the size of dimension 1 of parameter 0.
  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      SizeParam{1, {}}, DynamicDim{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(negated, {}, 1), dim_size);
}
200
TEST_F(DynamicDimensionInferenceTest, ReduceTestI) {
  // Reducing away dimensions 0 and 2 of a [1,2,2] operand whose dimension 1
  // is dynamic leaves that dynamic dimension as output dimension 0.
  HloComputation::Builder builder(TestName());
  const auto operand_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const auto result_shape = ShapeUtil::MakeShape(F32, {2});

  auto* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "data_param"));
  auto* dim_size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  auto* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(operand_shape, HloOpcode::kNegate, operand));
  auto* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));

  HloComputation* adder = GetAdd();
  auto* reduced = builder.AddInstruction(HloInstruction::CreateReduce(
      result_shape, negated, zero, {0, 2}, adder));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the size of dimension 1 of parameter 0.
  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      SizeParam{1, {}}, DynamicDim{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduced, {}, 0), dim_size);
}
231
TEST_F(DynamicDimensionInferenceTest, ReduceTestII) {
  // Like ReduceTestI, but only dimension 1 is reduced: the dynamic operand
  // dimension 2 lands on output dimension 1, and output dimension 0 stays
  // static.
  HloComputation::Builder builder(TestName());
  const auto operand_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const auto result_shape = ShapeUtil::MakeShape(F32, {1, 2});

  auto* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "data_param"));
  auto* dim_size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  auto* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(operand_shape, HloOpcode::kNegate, operand));
  auto* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));

  HloComputation* adder = GetAdd();
  auto* reduced = builder.AddInstruction(
      HloInstruction::CreateReduce(result_shape, negated, zero, {1}, adder));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the size of dimension 2 of parameter 0.
  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      SizeParam{1, {}}, DynamicDim{0, {}, 2}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduced, {}, 1), dim_size);
  EXPECT_EQ(inference_->GetDynamicSize(reduced, {}, 0), nullptr);
}
264
TEST_F(DynamicDimensionInferenceTest, VariadicReduce) {
  // A variadic reduce yields a tuple. Only the first operand is bound to a
  // dynamic size, yet both tuple outputs are expected to carry it on
  // dimension 1.
  HloComputation::Builder builder(TestName());
  const auto operand_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const auto result_shape = ShapeUtil::MakeShape(F32, {1, 2});

  auto* dynamic_operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "data_param"));
  auto* static_operand = builder.AddInstruction(
      HloInstruction::CreateParameter(1, operand_shape, "data_param.2"));
  auto* dim_size = builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_shape_, "size_param"));

  // Parameter 2 is the size of dimension 2 of parameter 0 only.
  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      SizeParam{2, {}}, DynamicDim{0, {}, 2}));

  auto* negated_dynamic = builder.AddInstruction(HloInstruction::CreateUnary(
      operand_shape, HloOpcode::kNegate, dynamic_operand));
  auto* negated_static = builder.AddInstruction(HloInstruction::CreateUnary(
      operand_shape, HloOpcode::kNegate, static_operand));
  auto* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));

  HloComputation* tuple_adder = GetAddTuple();
  auto* reduced = builder.AddInstruction(HloInstruction::CreateReduce(
      ShapeUtil::MakeTupleShape({result_shape, result_shape}),
      {negated_dynamic, negated_static}, {zero, zero}, {1}, tuple_adder));

  module_->AddEntryComputation(builder.Build());

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduced, {0}, 1), dim_size);
  EXPECT_EQ(inference_->GetDynamicSize(reduced, {1}, 1), dim_size);
  EXPECT_EQ(inference_->GetDynamicSize(reduced, {0}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(reduced, {1}, 0), nullptr);
}
305
TEST_F(DynamicDimensionInferenceTest, DotTest) {
  // A [3,2] x [2,1] -> [3,1] matrix multiply that contracts lhs dimension 1
  // with rhs dimension 0.
  HloComputation::Builder builder(TestName());
  constexpr int kRows = 3;
  constexpr int kInner = 2;
  constexpr int kCols = 1;
  const auto lhs_shape = ShapeUtil::MakeShape(F32, {kRows, kInner});
  const auto rhs_shape = ShapeUtil::MakeShape(F32, {kInner, kCols});
  const auto out_shape = ShapeUtil::MakeShape(F32, {kRows, kCols});

  auto* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, lhs_shape, "A"));
  auto* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, rhs_shape, "B"));
  auto* dim_size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  DotDimensionNumbers contraction;
  contraction.add_lhs_contracting_dimensions(1);
  contraction.add_rhs_contracting_dimensions(0);
  auto* dot = builder.AddInstruction(
      HloInstruction::CreateDot(out_shape, lhs, rhs, contraction,
                                HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(builder.Build());

  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  auto& binding = module_->dynamic_parameter_binding();
  // lhs dimension 0 is not contracted; it becomes output dimension 0.
  TF_CHECK_OK(binding.Bind(SizeParam{2, {}}, DynamicDim{0, {}, 0}));
  // The contracted pair (lhs dimension 1, rhs dimension 0) is dynamic as
  // well, but it is summed away and must not appear in the output.
  TF_CHECK_OK(binding.Bind(SizeParam{2, {}}, DynamicDim{0, {}, 1}));
  TF_CHECK_OK(binding.Bind(SizeParam{2, {}}, DynamicDim{1, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), dim_size);
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
}
349
TEST_F(DynamicDimensionInferenceTest, DotTestBatch) {
  // Batched dot: batch dimensions {0, 2} on both sides, contracting dimension
  // 3. A dynamic lhs batch dimension 0 should surface as output dimension 0.
  HloComputation::Builder builder(TestName());
  const auto lhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
  const auto rhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
  const auto out_shape = ShapeUtil::MakeShape(F32, {4, 2, 128, 128});

  auto* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, lhs_shape, "A"));
  auto* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, rhs_shape, "B"));
  auto* dim_size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  DotDimensionNumbers dims;
  dims.add_lhs_contracting_dimensions(3);
  dims.add_rhs_contracting_dimensions(3);
  dims.add_lhs_batch_dimensions(0);
  dims.add_lhs_batch_dimensions(2);
  dims.add_rhs_batch_dimensions(0);
  dims.add_rhs_batch_dimensions(2);
  auto* dot = builder.AddInstruction(
      HloInstruction::CreateDot(out_shape, lhs, rhs, dims,
                                HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(builder.Build());

  // Parameter 2 is the size of lhs batch dimension 0.
  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      SizeParam{2, {}}, DynamicDim{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  // Only output dimension 0 is dynamic.
  for (int dim = 0; dim < 4; ++dim) {
    EXPECT_EQ(inference_->GetDynamicSize(dot, {}, dim),
              dim == 0 ? dim_size : nullptr);
  }
}
388
TEST_F(DynamicDimensionInferenceTest, DotTestMultiContracting) {
  // Dot contracting two dimensions ({0, 1}) on each side. Every contracted
  // dimension is dynamic, but none of them survive into the output.
  HloComputation::Builder builder(TestName());
  const auto lhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 8, 64});
  const auto rhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 512});
  const auto out_shape = ShapeUtil::MakeShape(F32, {8, 64, 512});

  auto* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, lhs_shape, "A"));
  auto* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, rhs_shape, "B"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  DotDimensionNumbers contraction;
  contraction.add_lhs_contracting_dimensions(0);
  contraction.add_lhs_contracting_dimensions(1);
  contraction.add_rhs_contracting_dimensions(0);
  contraction.add_rhs_contracting_dimensions(1);
  auto* dot = builder.AddInstruction(
      HloInstruction::CreateDot(out_shape, lhs, rhs, contraction,
                                HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(builder.Build());

  // Parameter 2 sizes dimensions 0 and 1 of both operands, bound in the
  // order (A,0), (A,1), (B,0), (B,1).
  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  auto& binding = module_->dynamic_parameter_binding();
  for (int operand = 0; operand < 2; ++operand) {
    for (int dim = 0; dim < 2; ++dim) {
      TF_CHECK_OK(binding.Bind(SizeParam{2, {}}, DynamicDim{operand, {}, dim}));
    }
  }

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  // Nothing is dynamic in the output.
  for (int dim = 0; dim < 3; ++dim) {
    EXPECT_EQ(inference_->GetDynamicSize(dot, {}, dim), nullptr);
  }
}
435
TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
  // A window-less convolution of a [3,2] input with a [2,1] kernel, producing
  // a [1,3] result whose batch dimension is output dimension 1.
  HloComputation::Builder builder(TestName());
  constexpr int kX = 3;
  constexpr int kY = 2;
  constexpr int kZ = 1;
  const auto input_shape = ShapeUtil::MakeShape(F32, {kX, kY});
  const auto kernel_shape = ShapeUtil::MakeShape(F32, {kY, kZ});
  const auto output_shape = ShapeUtil::MakeShape(F32, {kZ, kX});

  auto* input = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, input_shape, "A"));
  auto* kernel = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, kernel_shape, "B"));
  auto* dim_size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  // Start from the default (zero spatial dims) numbering, then override:
  // kernel input/output features are dims 0/1, input batch is dim 0, and the
  // output puts features in dim 0 and batch in dim 1.
  auto conv_dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0);
  conv_dnums.set_kernel_input_feature_dimension(0);
  conv_dnums.set_kernel_output_feature_dimension(1);
  conv_dnums.set_input_batch_dimension(0);
  conv_dnums.set_output_batch_dimension(1);
  conv_dnums.set_output_feature_dimension(0);

  Window empty_window;

  auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      output_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, empty_window, conv_dnums,
      HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(builder.Build());

  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  auto& binding = module_->dynamic_parameter_binding();
  // Input dimension 0 is the input batch dimension; the convolution maps it
  // to output dimension 1 (output_batch_dimension).
  TF_CHECK_OK(binding.Bind(SizeParam{2, {}}, DynamicDim{0, {}, 0}));
  // Input dimension 1 is presumably the input feature dimension left at its
  // CreateDefaultConvDimensionNumbers value (TODO confirm). It is contracted
  // with the kernel and should not show up in the output.
  TF_CHECK_OK(binding.Bind(SizeParam{2, {}}, DynamicDim{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 1), dim_size);
  EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 0), nullptr);
}
484
TEST_F(DynamicDimensionInferenceTest, TransposeTest) {
  // A transpose permutes dimensions, so each dynamic size should follow its
  // dimension to its new position. Permutation {2, 1, 0} reverses them.
  HloComputation::Builder builder(TestName());
  const auto in_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
  const auto out_shape = ShapeUtil::MakeShape(F32, {3, 2, 1});

  auto* operand = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, in_shape, "A"));
  // dim_sizes[d] is parameter d + 1 and holds the size of operand dimension d.
  HloInstruction* dim_sizes[3];
  for (int d = 0; d < 3; ++d) {
    dim_sizes[d] = builder.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/d + 1, scalar_shape_, "size_param"));
  }

  auto* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(out_shape, operand, {2, 1, 0}));

  module_->AddEntryComputation(builder.Build());

  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  auto& binding = module_->dynamic_parameter_binding();
  for (int d = 0; d < 3; ++d) {
    TF_CHECK_OK(binding.Bind(SizeParam{d + 1, {}}, DynamicDim{0, {}, d}));
  }

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  // Output dimension i comes from operand dimension permutation[i].
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), dim_sizes[2]);
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), dim_sizes[1]);
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), dim_sizes[0]);
}
523
TEST_F(DynamicDimensionInferenceTest, NonDescendingTransposeTest) {
  // Same as TransposeTest but with the non-monotonic permutation {2, 0, 1}.
  HloComputation::Builder builder(TestName());
  const auto in_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
  const auto out_shape = ShapeUtil::MakeShape(F32, {3, 1, 2});

  auto* operand = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, in_shape, "A"));
  // dim_sizes[d] is parameter d + 1 and holds the size of operand dimension d.
  HloInstruction* dim_sizes[3];
  for (int d = 0; d < 3; ++d) {
    dim_sizes[d] = builder.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/d + 1, scalar_shape_, "size_param"));
  }

  auto* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(out_shape, operand, {2, 0, 1}));

  module_->AddEntryComputation(builder.Build());

  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  auto& binding = module_->dynamic_parameter_binding();
  for (int d = 0; d < 3; ++d) {
    TF_CHECK_OK(binding.Bind(SizeParam{d + 1, {}}, DynamicDim{0, {}, d}));
  }

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  // Output dimension i comes from operand dimension permutation[i].
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), dim_sizes[2]);
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), dim_sizes[0]);
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), dim_sizes[1]);
}
562
TEST_F(DynamicDimensionInferenceTest, ReshapeTest) {
  // Reshape [2,3,4,5,6] -> [6,4,1,5,2,3]. The dynamic input dimensions 2
  // (size 4) and 3 (size 5) are expected at output dimensions 1 and 3.
  HloComputation::Builder builder(TestName());
  const auto from_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6});
  const auto to_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3});

  auto* operand = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, from_shape, "A"));
  auto* dim_size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  auto* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(to_shape, operand));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 sizes both input dimension 2 and input dimension 3.
  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(SizeParam{1, {}}, DynamicDim{0, {}, 2}));
  TF_CHECK_OK(binding.Bind(SizeParam{1, {}}, DynamicDim{0, {}, 3}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  for (int dim = 0; dim < 6; ++dim) {
    const bool is_dynamic = dim == 1 || dim == 3;
    EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, dim),
              is_dynamic ? dim_size : nullptr);
  }
}
596
TEST_F(DynamicDimensionInferenceTest, ReshapeInferredDimensionTest) {
  // Reshape [4,5] -> [1,4,5], i.e. to a higher rank with the same element
  // count, with output dimension 0 marked as the inferred dimension. With
  // input dimension 0 dynamic, inference must produce some size for output
  // dimension 0.
  HloComputation::Builder builder(TestName());
  const auto from_shape = ShapeUtil::MakeShape(F32, {4, 5});
  const auto to_shape = ShapeUtil::MakeShape(F32, {1, 4, 5});

  auto* operand = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, from_shape, "A"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  auto* reshape = builder.AddInstruction(HloInstruction::CreateReshape(
      to_shape, operand, /*inferred_dimension=*/0));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the size of dimension 0 of parameter 0.
  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      SizeParam{1, {}}, DynamicDim{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  // Only presence is checked, not which instruction computes the size.
  EXPECT_NE(inference_->GetDynamicSize(reshape, {}, 0), nullptr);
}
622
TEST_F(DynamicDimensionInferenceTest, ReshapeTestMajorDimension) {
  // Reshape [32,10,4] -> [320,4] combines the dynamic major dimension 0 with
  // dimension 1. The combined output dimension 0 must still be reported as
  // dynamic.
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {32, 10, 4});
  auto output_shape = ShapeUtil::MakeShape(F32, {320, 4});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, input_shape, "A"));

  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  auto* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(output_shape, a_param));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the size of dimension 0 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  // Assert success. On failure this test's only RunInference() call never
  // sets `inference_`, and the lookup below would dereference a null pointer
  // instead of reporting the error.
  TF_ASSERT_OK(RunInference());
  EXPECT_NE(inference_->GetDynamicSize(reshape, {}, 0), nullptr);
}
648
TEST_F(DynamicDimensionInferenceTest, ReshapeIntoScalar) {
  // Tests that inference succeeds on a reshape of a rank-1 operand, whose only
  // dimension is dynamic, into a scalar.
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {1});
  auto output_shape = ShapeUtil::MakeShape(F32, {});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, input_shape, "A"));

  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  builder.AddInstruction(HloInstruction::CreateReshape(output_shape, a_param));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the size of dimension 0 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  // TF_ASSERT_OK (not TF_CHECK_OK) turns an inference error into a test
  // failure rather than aborting the whole test binary, matching the other
  // tests in this file.
  TF_ASSERT_OK(RunInference());
}
672
TEST_F(DynamicDimensionInferenceTest, GatherTest) {
  // Gather from a [20,10] operand with [32,20] indices. Dimension 0 of the
  // indices (parameter 1) is bound to a dynamic size and should become
  // dimension 0 of the gather result.
  const string hlo_text = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[20,10]{1,0} parameter(0)
  indices = s32[32,20] parameter(1)
  dynamic_size = s32[] parameter(2)
  ROOT gather = s32[32,20,10]{2,1,0} gather(%operand, %indices),
      offset_dims={2},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1,10}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));

  // Parameter 2 is the size of dimension 0 of parameter 1 (the indices).
  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      SizeParam{2, {}}, DynamicDim{1, {}, 0}));
  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  HloComputation* entry = module_->entry_computation();
  EXPECT_EQ(inference_->GetDynamicSize(entry->root_instruction(), {}, 0),
            entry->parameter_instruction(2));
}
700
TEST_F(DynamicDimensionInferenceTest, BroadcastTest) {
  // Broadcasting a [2] operand into [3,2,4] along output dimension 1 should
  // move the operand's dynamic dimension 0 to output dimension 1. The newly
  // added dimensions stay static.
  HloComputation::Builder builder(TestName());
  const auto operand_shape = ShapeUtil::MakeShape(F32, {2});
  const auto result_shape = ShapeUtil::MakeShape(F32, {3, 2, 4});

  auto* operand = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, operand_shape, "A"));
  auto* dim_size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  auto* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(result_shape, operand, {1}));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the size of dimension 0 of parameter 0.
  using SizeParam = DynamicParameterBinding::DynamicParameter;
  using DynamicDim = DynamicParameterBinding::DynamicDimension;
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      SizeParam{1, {}}, DynamicDim{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  for (int dim = 0; dim < 3; ++dim) {
    EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, dim),
              dim == 1 ? dim_size : nullptr);
  }
}
727
TEST_F(DynamicDimensionInferenceTest, WhileTest) {
  // Test the ability to trace dynamic dimensions into and out of while loops.
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  auto tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape});

  // Body:
  //
  //   Param
  //   |   |
  // GTE1 GTE2
  //   |   |
  //    ADD
  //     |
  // Tuple(ADD, ADD)
  auto body_builder = HloComputation::Builder("body");
  auto body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  auto gte_0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(input_shape, body_param, 0));
  auto gte_1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(input_shape, body_param, 1));
  auto add = body_builder.AddInstruction(
      HloInstruction::CreateBinary(input_shape, HloOpcode::kAdd, gte_0, gte_1));
  body_builder.AddInstruction(HloInstruction::CreateTuple({add, add}));

  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  // Condition: takes the loop tuple; its last instruction is a constant false.
  auto cond_builder = HloComputation::Builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Entry:
  //
  // Param
  //   |
  // While
  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, tuple_shape, "A"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));
  builder.AddInstruction(
      HloInstruction::CreateWhile(tuple_shape, condition, body, a_param));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the runtime size of dimension 0 of both tuple elements of
  // parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {0}, 0}));

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {1}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  HloInstruction* while_hlo = nullptr;
  // The while hlo has been replaced, find the new one.
  for (HloInstruction* inst : module_->entry_computation()->instructions()) {
    if (inst->opcode() == HloOpcode::kWhile) {
      while_hlo = inst;
    }
  }
  ASSERT_NE(while_hlo, nullptr);
  // The original while shape has 2 elements. With dynamic size, the tuple
  // should have 4 elements (We don't deduplicate the arguments).
  EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 4);
  HloInstruction* add_inst = nullptr;
  for (HloInstruction* inst : while_hlo->while_body()->instructions()) {
    if (inst->opcode() == HloOpcode::kAdd) {
      add_inst = inst;
    }
  }
  // Fatal: querying the dynamic size of a missing add is meaningless.
  ASSERT_NE(add_inst, nullptr);
  EXPECT_NE(inference_->GetDynamicSize(add_inst, {}, 0), nullptr);
  EXPECT_NE(inference_->GetDynamicSize(
                module_->entry_computation()->root_instruction(), {0}, 0),
            nullptr);
  EXPECT_NE(inference_->GetDynamicSize(
                module_->entry_computation()->root_instruction(), {1}, 0),
            nullptr);
}
812
TEST_F(DynamicDimensionInferenceTest, ConditionalInputTest) {
  // Test the ability to trace dynamic dimensions into and out of conditional
  // branches.
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  // In this test the inputs to the two branches have different shapes.
  auto tuple_shape_1 = ShapeUtil::MakeTupleShape({input_shape});
  auto tuple_shape_2 = ShapeUtil::MakeTupleShape({input_shape, input_shape});
  auto tuple_shape_3 =
      ShapeUtil::MakeTupleShape({input_shape, input_shape, input_shape});

  // true branch:
  //
  //    Param
  //    |   |
  //  GTE1 GTE2
  //    |   |
  //   Tuple(ADD)
  auto true_builder = HloComputation::Builder("true");
  {
    auto true_param = true_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tuple_shape_2, "param"));
    auto gte_0 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(input_shape, true_param, 0));
    auto gte_1 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(input_shape, true_param, 1));
    auto add = true_builder.AddInstruction(HloInstruction::CreateBinary(
        input_shape, HloOpcode::kAdd, gte_0, gte_1));
    true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
  }
  HloComputation* true_branch =
      module_->AddEmbeddedComputation(true_builder.Build());
  // false branch (tuple element 0 of the parameter is not read):
  //
  //      Param
  //    |   |   |
  //  GTE1 GTE2 GTE3
  //        |   |
  //      Tuple(ADD)
  auto false_builder = HloComputation::Builder("false");
  {
    auto false_param = false_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tuple_shape_3, "param"));
    auto gte_0 = false_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(input_shape, false_param, 1));
    auto gte_1 = false_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(input_shape, false_param, 2));
    auto add = false_builder.AddInstruction(HloInstruction::CreateBinary(
        input_shape, HloOpcode::kAdd, gte_0, gte_1));
    false_builder.AddInstruction(HloInstruction::CreateTuple({add}));
  }
  HloComputation* false_branch =
      module_->AddEmbeddedComputation(false_builder.Build());

  // Entry:
  //
  // Param(bool) Param2 (tuple_2) Param3(tuple_3)
  //   |              |                |
  //   +---------Conditional-----------+
  auto* pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, ShapeUtil::MakeScalarShape(PRED), "pred"));

  auto* tuple_2_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, tuple_shape_2, "tuple_2_param"));
  auto* tuple_3_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, tuple_shape_3, "tuple_3_param"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/3, scalar_shape_, "size_param"));
  builder.AddInstruction(HloInstruction::CreateConditional(
      tuple_shape_1, pred_param, tuple_2_param, true_branch, tuple_3_param,
      false_branch));

  module_->AddEntryComputation(builder.Build());

  // Parameter 3 is the runtime size of dimension 0 of tuple elements {0} and
  // {1} of parameter 1, and of tuple elements {1} and {2} of parameter 2 (the
  // elements each branch actually reads).
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{1, {0}, 0}));
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{1, {1}, 0}));
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{2, {1}, 0}));
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{2, {2}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  HloInstruction* conditional_hlo = nullptr;
  // The conditional hlo has been replaced, find the new one.
  for (HloInstruction* inst : module_->entry_computation()->instructions()) {
    if (inst->opcode() == HloOpcode::kConditional) {
      conditional_hlo = inst;
    }
  }
  ASSERT_NE(conditional_hlo, nullptr);
  // The original conditional shape has 1 element. With the dynamic size passed
  // out from the branch computations, another element is added to the tuple.
  EXPECT_EQ(conditional_hlo->shape().tuple_shapes_size(), 2);
  HloInstruction* add_true_branch = nullptr;
  for (HloInstruction* inst :
       conditional_hlo->true_computation()->instructions()) {
    if (inst->opcode() == HloOpcode::kAdd) {
      add_true_branch = inst;
    }
  }
  // Fatal: querying the dynamic size of a missing add is meaningless.
  ASSERT_NE(add_true_branch, nullptr);
  EXPECT_NE(inference_->GetDynamicSize(add_true_branch, {}, 0), nullptr);

  HloInstruction* add_false_branch = nullptr;
  for (HloInstruction* inst :
       conditional_hlo->false_computation()->instructions()) {
    if (inst->opcode() == HloOpcode::kAdd) {
      add_false_branch = inst;
    }
  }
  ASSERT_NE(add_false_branch, nullptr);
  EXPECT_NE(inference_->GetDynamicSize(add_false_branch, {}, 0), nullptr);

  EXPECT_NE(inference_->GetDynamicSize(conditional_hlo, {0}, 0), nullptr);
}
935
TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) {
  // A reduce-window whose window leaves dimension 0 untouched (size 1,
  // stride 1) must keep that dimension's dynamic size. Dimensions 1 and 2 are
  // reduced by a size-2, stride-2 window.
  auto builder = HloComputation::Builder(TestName());
  const auto operand_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  const auto reduced_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});

  Window window;
  // Appends one window dimension with zero padding and unit dilations.
  auto append_window_dim = [&window](int64 size, int64 stride) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(size);
    dim->set_stride(stride);
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  };
  append_window_dim(/*size=*/1, /*stride=*/1);  // dimension 0: unchanged
  append_window_dim(/*size=*/2, /*stride=*/2);  // dimension 1: halved
  append_window_dim(/*size=*/2, /*stride=*/2);  // dimension 2: halved

  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/0, operand_shape, "A"));
  HloInstruction* size_param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/1, scalar_shape_, "size_param"));
  HloInstruction* init_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* reduce_window =
      builder.AddInstruction(HloInstruction::CreateReduceWindow(
          reduced_shape, operand, init_value, window, GetAdd()));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the runtime size of dimension 0 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size_param);
}
985
TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) {
  // A select-and-scatter whose window leaves dimension 0 untouched (size 1,
  // stride 1) must keep that dimension's dynamic size. Dimension 0 of both the
  // operand and the source is bound to the same size parameter.
  auto builder = HloComputation::Builder(TestName());
  const auto operand_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  const auto source_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});

  Window window;
  // Appends one window dimension with zero padding and unit dilations.
  auto append_window_dim = [&window](int64 size, int64 stride) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(size);
    dim->set_stride(stride);
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  };
  append_window_dim(/*size=*/1, /*stride=*/1);  // dimension 0: unchanged
  append_window_dim(/*size=*/2, /*stride=*/2);  // dimension 1: halved
  append_window_dim(/*size=*/2, /*stride=*/2);  // dimension 2: halved

  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/0, operand_shape, "A"));
  HloInstruction* size_param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/1, scalar_shape_, "size_param"));
  HloInstruction* source =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/2, source_shape, "B"));
  HloInstruction* init_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* select_and_scatter =
      builder.AddInstruction(HloInstruction::CreateSelectAndScatter(
          operand_shape, operand, GetGe(), window, source, init_value,
          GetAdd()));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the runtime size of dimension 0 of parameters 0 and 2.
  for (int data_param_number : {0, 2}) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
        DynamicParameterBinding::DynamicParameter{1, {}},
        DynamicParameterBinding::DynamicDimension{data_param_number, {}, 0}));
  }

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(select_and_scatter, {}, 0),
            size_param);
}
1039
TEST_F(DynamicDimensionInferenceTest, ConcateTest) {
  // Concatenates two parameters along dimension 1. Dimension 0 of both inputs
  // is dynamic with the same size parameter, so the concat's dimension 0 is
  // expected to carry that size.
  auto builder = HloComputation::Builder(TestName());

  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param_1"));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {5, 8}), "data_param_2"));
  HloInstruction* size_param = builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_shape_, "size_param"));

  const auto concat_shape = ShapeUtil::MakeShape(F32, {5, 15});
  HloInstruction* concat =
      builder.AddInstruction(HloInstruction::CreateConcatenate(
          concat_shape, {lhs, rhs}, /*dimension=*/1));

  module_->AddEntryComputation(builder.Build());

  // Parameter 2 is the runtime size of dimension 0 of parameters 0 and 1.
  for (int data_param_number : {0, 1}) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
        DynamicParameterBinding::DynamicParameter{2, {}},
        DynamicParameterBinding::DynamicDimension{data_param_number, {}, 0}));
  }

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(concat, {}, 0), size_param);
}
1067
TEST_F(DynamicDimensionInferenceTest, SliceTest) {
  // A full-extent, unit-stride slice must preserve the dynamic size of its
  // operand's dimension 1.
  auto builder = HloComputation::Builder(TestName());
  const auto data_shape = ShapeUtil::MakeShape(F32, {5, 7});

  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data_param"));
  HloInstruction* size_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
      data_shape, data, /*start_indices=*/{0, 0},
      /*limit_indices=*/{5, 7}, /*strides=*/{1, 1}));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the runtime size of dimension 1 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(slice, {}, 1), size_param);
}
1089
TEST_F(DynamicDimensionInferenceTest, DynamicSliceTest) {
  // A dynamic slice that keeps the full extent (5 of 5) of dynamic dimension 0
  // is expected to keep that dimension's dynamic size.
  auto builder = HloComputation::Builder(TestName());

  HloInstruction* data = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  HloInstruction* size_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // One scalar s32 start index per operand dimension (parameters 2 and 3).
  std::vector<HloInstruction*> start_indices;
  for (int param_number : {2, 3}) {
    start_indices.push_back(
        builder.AddInstruction(HloInstruction::CreateParameter(
            param_number, ShapeUtil::MakeShape(S32, {}), "slice_indices")));
  }

  HloInstruction* slice =
      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
          ShapeUtil::MakeShape(F32, {5, 1}), data, start_indices,
          /*slice_sizes=*/{5, 1}));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the runtime size of dimension 0 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(slice, {}, 0), size_param);
}
1117
TEST_F(DynamicDimensionInferenceTest, SortTest) {
  // Sorting along dimension 1 must keep the dynamic size of dimension 0.
  auto builder = HloComputation::Builder(TestName());

  HloInstruction* data = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  HloInstruction* size_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // Comparator over two f32 scalars whose body is only a constant `false`.
  const auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
  auto compare_builder = HloComputation::Builder("condition");
  compare_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param1"));
  compare_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_f32, "param2"));
  compare_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* comparator =
      module_->AddEmbeddedComputation(compare_builder.Build());

  HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort(
      ShapeUtil::MakeShape(F32, {5, 7}), /*dimension=*/1, {data}, comparator,
      /*is_stable=*/false));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the runtime size of dimension 0 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(sort, {}, 0), size_param);
}
1149
TEST_F(DynamicDimensionInferenceTest, MultiValueSortTest) {
  // A two-operand sort along dimension 1 yields a tuple; the dynamic
  // dimension 0 of the input must show up on both tuple elements.
  auto builder = HloComputation::Builder(TestName());
  const auto data_shape = ShapeUtil::MakeShape(F32, {5, 7});

  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data_param"));
  HloInstruction* size_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // Comparator over four f32 scalars whose body is only a constant `false`.
  const auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
  const char* const kCompareParamNames[] = {"param1", "param2", "param3",
                                            "param4"};
  auto compare_builder = HloComputation::Builder("condition");
  for (int param_number = 0; param_number < 4; ++param_number) {
    compare_builder.AddInstruction(HloInstruction::CreateParameter(
        param_number, scalar_f32, kCompareParamNames[param_number]));
  }
  compare_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* comparator =
      module_->AddEmbeddedComputation(compare_builder.Build());

  const auto sort_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape});
  HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort(
      sort_shape, /*dimension=*/1, {data, data}, comparator,
      /*is_stable=*/false));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the runtime size of dimension 0 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(sort, {0}, 0), size_param);
  EXPECT_EQ(inference_->GetDynamicSize(sort, {1}, 0), size_param);
}
1189
TEST_F(DynamicDimensionInferenceTest, DynamicSliceSingleElementTest) {
  // Slicing out a single element from a dynamic dimension terminates the
  // dynamic dimension: the size-1 result dimension is static.
  auto builder = HloComputation::Builder(TestName());

  HloInstruction* data = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // One scalar s32 start index per operand dimension (parameters 2 and 3).
  std::vector<HloInstruction*> start_indices;
  for (int param_number : {2, 3}) {
    start_indices.push_back(
        builder.AddInstruction(HloInstruction::CreateParameter(
            param_number, ShapeUtil::MakeShape(S32, {}), "slice_indices")));
  }

  HloInstruction* slice =
      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
          ShapeUtil::MakeShape(F32, {1, 1}), data, start_indices,
          /*slice_sizes=*/{1, 1}));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the runtime size of dimension 0 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(slice, {}, 0), nullptr);
}
1219
TEST_F(DynamicDimensionInferenceTest, InfersCustomOp) {
  // Verifies that RunInference hands a custom call whose operand has a dynamic
  // dimension to the user-supplied custom-call handler.
  auto builder = HloComputation::Builder(TestName());

  auto data_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  builder.AddInstruction(HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {1, 1}), {data_param}, "MyCustomOp", ""));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the runtime size of dimension 0 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  bool handler_called = false;
  // Use non-fatal gtest expectations rather than CHECK so that a bad argument
  // is reported as a test failure instead of aborting the whole test binary.
  // DynCast yields null on a type mismatch (Cast would CHECK-fail instead).
  auto handler = [&](HloInstruction* hlo,
                     DynamicDimensionInference* inference) {
    EXPECT_NE(inference, nullptr);
    auto* custom_call = DynCast<HloCustomCallInstruction>(hlo);
    EXPECT_NE(custom_call, nullptr);
    if (custom_call != nullptr) {
      EXPECT_EQ(custom_call->custom_call_target(), "MyCustomOp");
    }
    handler_called = true;
    return Status::OK();
  };
  TF_ASSERT_OK(RunInference(handler));

  EXPECT_TRUE(handler_called);
}
1250
TEST_F(DynamicDimensionInferenceTest, DynamicReshapeOp) {
  // Dynamic-reshapes f32[<=9] (runtime size 6) into f32[3, <=3] with output
  // sizes {constant 3, parameter 1}. Output dimension 0 is expected to be
  // static and output dimension 1 to take its size from parameter 1.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* data = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {9}), "data_input"));
  HloInstruction* six = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(6)));
  // f32[<=9] whose dimension 0 has runtime size 6.
  HloInstruction* bounded_data =
      builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
          ShapeUtil::MakeShape(F32, {9}, {true}), data, six,
          /*dimension=*/0));
  HloInstruction* inner_size =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, ShapeUtil::MakeShape(S32, {}), "size_param"));
  HloInstruction* three = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(3)));

  const auto reshaped_shape =
      ShapeUtil::MakeShape(F32, {3, 3}, {false, true});
  HloInstruction* dynamic_reshape =
      builder.AddInstruction(HloInstruction::CreateDynamicReshape(
          reshaped_shape, bounded_data, {three, inner_size}));

  module_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 1), inner_size);
}
1279
TEST_F(DynamicDimensionInferenceTest, ReshapeOpWithMultipleDynamicDimensions) {
  // Reshaping f32[<=9, <=2] into f32[<=9, 1, <=2] only inserts a unit
  // dimension. Input dimension 0 (size 6) should map to output dimension 0,
  // and input dimension 1 (size 1) to output dimension 2.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* data = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {9, 2}), "data_input"));
  HloInstruction* six = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(6)));
  // Mark dimension 0 dynamic with runtime size 6.
  HloInstruction* dim0_dynamic =
      builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
          ShapeUtil::MakeShape(F32, {9, 2}, {true, false}), data, six,
          /*dimension=*/0));
  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
  // Additionally mark dimension 1 dynamic with runtime size 1.
  HloInstruction* both_dynamic =
      builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
          ShapeUtil::MakeShape(F32, {9, 2}, {true, true}), dim0_dynamic, one,
          /*dimension=*/1));

  const auto reshaped_shape =
      ShapeUtil::MakeShape(F32, {9, 1, 2}, {true, false, true});
  HloInstruction* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(reshaped_shape, both_dynamic));

  module_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 0), six);
  EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 1), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 2), one);
}
1305
1306 } // namespace
1307 } // namespace xla
1308