• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
17 
18 #include "tensorflow/compiler/xla/client/xla_builder.h"
19 #include "tensorflow/compiler/xla/literal.h"
20 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
24 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/hlo_runner.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test_benchmark.h"
35 
36 namespace op = xla::testing::opcode_matchers;
37 
38 namespace xla {
39 namespace {
40 
41 class DynamicDimensionInferenceTest : public HloTestBase {
42  protected:
DynamicDimensionInferenceTest()43   DynamicDimensionInferenceTest() : HloTestBase() {
44     module_ = CreateNewVerifiedModule();
45   }
46 
RunInference(DynamicDimensionInference::CustomCallInferenceHandler handler=nullptr)47   Status RunInference(
48       DynamicDimensionInference::CustomCallInferenceHandler handler = nullptr) {
49     TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference,
50                         DynamicDimensionInference::Run(module_.get(), handler));
51 
52     inference_ = absl::make_unique<DynamicDimensionInference>(inference);
53     return Status::OK();
54   }
55 
GetAdd()56   HloComputation* GetAdd() {
57     auto embedded_builder = HloComputation::Builder("add");
58     auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
59         0, ShapeUtil::MakeShape(F32, {}), "lhs"));
60     auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
61         1, ShapeUtil::MakeShape(F32, {}), "rhs"));
62     embedded_builder.AddInstruction(
63         HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
64     return module_->AddEmbeddedComputation(embedded_builder.Build());
65   }
66 
GetAddTuple()67   HloComputation* GetAddTuple() {
68     auto embedded_builder = HloComputation::Builder("add");
69     auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
70         0, ShapeUtil::MakeShape(F32, {}), "lhs"));
71     auto lhs_1 =
72         embedded_builder.AddInstruction(HloInstruction::CreateParameter(
73             1, ShapeUtil::MakeShape(F32, {}), "lhs.1"));
74     auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
75         2, ShapeUtil::MakeShape(F32, {}), "rhs"));
76     auto rhs_1 =
77         embedded_builder.AddInstruction(HloInstruction::CreateParameter(
78             3, ShapeUtil::MakeShape(F32, {}), "rhs.1"));
79     auto add = embedded_builder.AddInstruction(
80         HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
81     auto add_1 = embedded_builder.AddInstruction(HloInstruction::CreateBinary(
82         lhs->shape(), HloOpcode::kAdd, lhs_1, rhs_1));
83     embedded_builder.AddInstruction(HloInstruction::CreateTuple({add, add_1}));
84     return module_->AddEmbeddedComputation(embedded_builder.Build());
85   }
86 
GetGe()87   HloComputation* GetGe() {
88     auto embedded_builder = HloComputation::Builder("ge");
89     auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
90         0, ShapeUtil::MakeShape(F32, {}), "lhs"));
91     auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
92         1, ShapeUtil::MakeShape(F32, {}), "rhs"));
93     embedded_builder.AddInstruction(HloInstruction::CreateCompare(
94         ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kGe));
95     return module_->AddEmbeddedComputation(embedded_builder.Build());
96   }
97 
98   std::unique_ptr<HloModule> module_;
99   std::unique_ptr<DynamicDimensionInference> inference_;
100   const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {});
101 };
102 
// A dynamic size bound directly to an entry parameter is reported for exactly
// the bound (instruction, index, dimension) triple and nowhere else.
TEST_F(DynamicDimensionInferenceTest, ParamTest) {
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});

  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  // Distinct name so the two parameters are easy to tell apart in the
  // SCOPED_TRACE dump of the module.
  auto param2 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  // Parameter 1 holds the runtime size of dimension 1 of parameter 0.
  // ASSERT (not CHECK) so a binding failure fails this test instead of
  // aborting the whole test binary.
  TF_ASSERT_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(param, {}, 1), param2);
  EXPECT_EQ(inference_->GetDynamicSize(param, {}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(param2, {}, 0), nullptr);
}
125 
// A dynamic size stored inside a tuple parameter is surfaced as a
// get-tuple-element of that same parameter.
TEST_F(DynamicDimensionInferenceTest, ParamTestTuple) {
  const Shape data_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  auto b = HloComputation::Builder(TestName());

  HloInstruction* tuple_param =
      b.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeTupleShape({data_shape, scalar_shape_}), "param"));

  module_->AddEntryComputation(b.Build());

  // Element {1} of parameter 0 is the size of dimension 1 of element {0}.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{0, {1}},
                           DynamicParameterBinding::DynamicDimension{0, {0}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  EXPECT_THAT(inference_->GetDynamicSize(tuple_param, {0}, 1),
              op::GetTupleElement(tuple_param, 1));
  EXPECT_EQ(inference_->GetDynamicSize(tuple_param, {0}, 0), nullptr);
}
146 
// Passing through a get-tuple-element keeps the dynamic size of the selected
// element; only the leading tuple index is dropped ({0} -> {}).
TEST_F(DynamicDimensionInferenceTest, GetTupleElement) {
  const Shape data_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  auto b = HloComputation::Builder(TestName());

  HloInstruction* tuple_param =
      b.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeTupleShape({data_shape, scalar_shape_}), "param"));
  HloInstruction* element0 = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));

  module_->AddEntryComputation(b.Build());

  // Element {1} of parameter 0 is the size of dimension 1 of element {0}.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{0, {1}},
                           DynamicParameterBinding::DynamicDimension{0, {0}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  EXPECT_THAT(inference_->GetDynamicSize(tuple_param, {0}, 1),
              op::GetTupleElement(tuple_param, 1));
  EXPECT_THAT(inference_->GetDynamicSize(element0, {}, 1),
              op::GetTupleElement(tuple_param, 1));
  EXPECT_EQ(inference_->GetDynamicSize(tuple_param, {0}, 0), nullptr);
}
175 
// An elementwise op (negate) forwards its operand's dynamic size unchanged.
TEST_F(DynamicDimensionInferenceTest, ElementwiseTest) {
  const Shape data_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  auto b = HloComputation::Builder(TestName());

  HloInstruction* data = b.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data_param"));
  HloInstruction* dim_size = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  HloInstruction* negated = b.AddInstruction(
      HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, data));

  module_->AddEntryComputation(b.Build());

  // Parameter 1 is the size of dimension 1 of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  EXPECT_EQ(inference_->GetDynamicSize(negated, {}, 1), dim_size);
}
200 
// Reducing away dimensions {0, 2} of a [1, 2, 2] operand leaves input
// dimension 1 as output dimension 0, and its dynamic size must follow it.
TEST_F(DynamicDimensionInferenceTest, ReduceTestI) {
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2});
  auto b = HloComputation::Builder(TestName());

  HloInstruction* data = b.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "data_param"));
  HloInstruction* dim_size = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  HloInstruction* negated = b.AddInstruction(
      HloInstruction::CreateUnary(operand_shape, HloOpcode::kNegate, data));
  HloInstruction* zero = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* reduce = b.AddInstruction(HloInstruction::CreateReduce(
      result_shape, negated, zero, {0, 2}, GetAdd()));

  module_->AddEntryComputation(b.Build());

  // Parameter 1 is the size of dimension 1 of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), dim_size);
}
231 
// Like ReduceTestI but only dimension 1 is reduced: dynamic input dimension 2
// becomes output dimension 1, and output dimension 0 stays static.
TEST_F(DynamicDimensionInferenceTest, ReduceTestII) {
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 2});
  auto b = HloComputation::Builder(TestName());

  HloInstruction* data = b.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "data_param"));
  HloInstruction* dim_size = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  HloInstruction* negated = b.AddInstruction(
      HloInstruction::CreateUnary(operand_shape, HloOpcode::kNegate, data));
  HloInstruction* zero = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* reduce = b.AddInstruction(
      HloInstruction::CreateReduce(result_shape, negated, zero, {1}, GetAdd()));

  module_->AddEntryComputation(b.Build());

  // Parameter 1 is the size of dimension 2 of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 2}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 1), dim_size);
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), nullptr);
}
264 
// Variadic reduce with a tuple result: only the first operand is dynamic, yet
// both tuple outputs are expected to report the dynamic size on dimension 1.
TEST_F(DynamicDimensionInferenceTest, VariadicReduce) {
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 2});
  auto b = HloComputation::Builder(TestName());

  HloInstruction* dynamic_data = b.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "data_param"));
  HloInstruction* static_data = b.AddInstruction(
      HloInstruction::CreateParameter(1, operand_shape, "data_param.2"));
  HloInstruction* dim_size = b.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_shape_, "size_param"));

  // Parameter 2 is the size of dimension 2 of parameter 0 (the dynamic input).
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 2}));

  HloInstruction* dynamic_negated = b.AddInstruction(
      HloInstruction::CreateUnary(operand_shape, HloOpcode::kNegate,
                                  dynamic_data));
  HloInstruction* static_negated = b.AddInstruction(
      HloInstruction::CreateUnary(operand_shape, HloOpcode::kNegate,
                                  static_data));
  HloInstruction* zero = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* reduce = b.AddInstruction(HloInstruction::CreateReduce(
      ShapeUtil::MakeTupleShape({result_shape, result_shape}),
      {dynamic_negated, static_negated}, {zero, zero}, {1}, GetAddTuple()));

  module_->AddEntryComputation(b.Build());

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  EXPECT_EQ(inference_->GetDynamicSize(reduce, {0}, 1), dim_size);
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {1}, 1), dim_size);
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {0}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {1}, 0), nullptr);
}
305 
// [x, y] . [y, z] -> [x, z]. A dynamic non-contracting lhs dimension (x)
// reaches the output; dynamic contracting dimensions (y on both sides) do not.
TEST_F(DynamicDimensionInferenceTest, DotTest) {
  constexpr int xdim = 3;
  constexpr int ydim = 2;
  constexpr int zdim = 1;
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
  const Shape out_shape = ShapeUtil::MakeShape(F32, {xdim, zdim});
  auto b = HloComputation::Builder(TestName());

  HloInstruction* lhs = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, lhs_shape, "A"));
  HloInstruction* rhs = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, rhs_shape, "B"));
  HloInstruction* dim_size = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  DotDimensionNumbers dims;
  dims.add_lhs_contracting_dimensions(1);
  dims.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = b.AddInstruction(HloInstruction::CreateDot(
      out_shape, lhs, rhs, dims, HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(b.Build());

  auto& binding = module_->dynamic_parameter_binding();
  // lhs dimension 0 (x): non-contracting.
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 0}));
  // lhs dimension 1 and rhs dimension 0 (y): contracting.
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 1}));
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{1, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), dim_size);
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
}
349 
// Batched dot with batch dimensions {0, 2} on both sides: a dynamic lhs batch
// dimension 0 shows up as output dimension 0 and nowhere else.
TEST_F(DynamicDimensionInferenceTest, DotTestBatch) {
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {4, 128, 2, 8});
  const Shape out_shape = ShapeUtil::MakeShape(F32, {4, 2, 128, 128});
  auto b = HloComputation::Builder(TestName());

  HloInstruction* lhs = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, lhs_shape, "A"));
  HloInstruction* rhs = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, rhs_shape, "B"));
  HloInstruction* dim_size = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  DotDimensionNumbers dims;
  dims.add_lhs_contracting_dimensions(3);
  dims.add_rhs_contracting_dimensions(3);
  dims.add_lhs_batch_dimensions(0);
  dims.add_lhs_batch_dimensions(2);
  dims.add_rhs_batch_dimensions(0);
  dims.add_rhs_batch_dimensions(2);
  HloInstruction* dot = b.AddInstruction(HloInstruction::CreateDot(
      out_shape, lhs, rhs, dims, HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(b.Build());

  // Parameter 2 is the size of lhs batch dimension 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), dim_size);
  for (int64 dim : {1, 2, 3}) {
    EXPECT_EQ(inference_->GetDynamicSize(dot, {}, dim), nullptr);
  }
}
388 
// Two contracting dimensions on each side, all of them dynamic: since every
// dynamic dimension is contracted away, no output dimension is dynamic.
TEST_F(DynamicDimensionInferenceTest, DotTestMultiContracting) {
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 8, 64});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 2, 512});
  const Shape out_shape = ShapeUtil::MakeShape(F32, {8, 64, 512});
  auto b = HloComputation::Builder(TestName());

  HloInstruction* lhs = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, lhs_shape, "A"));
  HloInstruction* rhs = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, rhs_shape, "B"));
  b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  DotDimensionNumbers dims;
  dims.add_lhs_contracting_dimensions(0);
  dims.add_lhs_contracting_dimensions(1);
  dims.add_rhs_contracting_dimensions(0);
  dims.add_rhs_contracting_dimensions(1);
  HloInstruction* dot = b.AddInstruction(HloInstruction::CreateDot(
      out_shape, lhs, rhs, dims, HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(b.Build());

  // Parameter 2 sizes dimensions 0 and 1 of both A (param 0) and B (param 1).
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 0}));
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 1}));
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{1, {}, 0}));
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{1, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  for (int64 dim : {0, 1, 2}) {
    EXPECT_EQ(inference_->GetDynamicSize(dot, {}, dim), nullptr);
  }
}
435 
// Convolution with no spatial dimensions. The input batch dimension (0) is
// dynamic and must appear at the output batch dimension (1); input
// dimension 1, which is also bound as dynamic, must not leak into output
// feature dimension 0.
TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
  constexpr int xdim = 3;
  constexpr int ydim = 2;
  constexpr int zdim = 1;
  const Shape input_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
  const Shape kernel_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
  const Shape out_shape = ShapeUtil::MakeShape(F32, {zdim, xdim});
  auto b = HloComputation::Builder(TestName());

  HloInstruction* input = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, input_shape, "A"));
  HloInstruction* kernel = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, kernel_shape, "B"));
  HloInstruction* dim_size = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  // Start from the default numbers for 0 spatial dims and override the
  // kernel/output layout; input feature dimension is not overridden
  // (presumably 1 in the defaults — TODO confirm).
  auto conv_dims = XlaBuilder::CreateDefaultConvDimensionNumbers(0);
  conv_dims.set_kernel_input_feature_dimension(0);
  conv_dims.set_kernel_output_feature_dimension(1);
  conv_dims.set_input_batch_dimension(0);
  conv_dims.set_output_batch_dimension(1);
  conv_dims.set_output_feature_dimension(0);

  Window no_window;

  HloInstruction* conv = b.AddInstruction(HloInstruction::CreateConvolve(
      out_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, no_window, conv_dims,
      HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(b.Build());

  auto& binding = module_->dynamic_parameter_binding();
  // Input dimension 0: the input batch dimension.
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 0}));
  // Input dimension 1: the dimension combined with the kernel.
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 1), dim_size);
  EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 0), nullptr);
}
484 
// Every input dimension has its own dynamic size; a full reversal transpose
// ({2, 1, 0}) must carry each size to the permuted output position.
TEST_F(DynamicDimensionInferenceTest, TransposeTest) {
  const Shape in_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
  const Shape out_shape = ShapeUtil::MakeShape(F32, {3, 2, 1});
  auto b = HloComputation::Builder(TestName());

  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, in_shape, "A"));
  HloInstruction* size_of_dim0 = b.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/1, scalar_shape_,
                                      "size_param"));
  HloInstruction* size_of_dim1 = b.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/2, scalar_shape_,
                                      "size_param"));
  HloInstruction* size_of_dim2 = b.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/3, scalar_shape_,
                                      "size_param"));

  HloInstruction* transpose = b.AddInstruction(
      HloInstruction::CreateTranspose(out_shape, operand, {2, 1, 0}));

  module_->AddEntryComputation(b.Build());

  // Parameter (k + 1) is the size of dimension k of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 0}));
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 1}));
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{3, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 2}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  // Output dimension i comes from input dimension permutation[i].
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_of_dim2);
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_of_dim1);
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_of_dim0);
}
523 
// Same as TransposeTest but with the rotation permutation {2, 0, 1}, so the
// mapping from input to output dimensions is not a simple reversal.
TEST_F(DynamicDimensionInferenceTest, NonDescendingTransposeTest) {
  const Shape in_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
  const Shape out_shape = ShapeUtil::MakeShape(F32, {3, 1, 2});
  auto b = HloComputation::Builder(TestName());

  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, in_shape, "A"));
  HloInstruction* size_of_dim0 = b.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/1, scalar_shape_,
                                      "size_param"));
  HloInstruction* size_of_dim1 = b.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/2, scalar_shape_,
                                      "size_param"));
  HloInstruction* size_of_dim2 = b.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/3, scalar_shape_,
                                      "size_param"));

  HloInstruction* transpose = b.AddInstruction(
      HloInstruction::CreateTranspose(out_shape, operand, {2, 0, 1}));

  module_->AddEntryComputation(b.Build());

  // Parameter (k + 1) is the size of dimension k of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 0}));
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{2, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 1}));
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{3, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 2}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  // Output dimension i comes from input dimension permutation[i].
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_of_dim2);
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_of_dim0);
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_of_dim1);
}
562 
// Reshape [2, 3, 4, 5, 6] -> [6, 4, 1, 5, 2, 3]. Input dimensions 2 (size 4)
// and 3 (size 5) are dynamic, and their sizes are expected at output
// dimensions 1 and 3 respectively; every other output dimension is static.
TEST_F(DynamicDimensionInferenceTest, ReshapeTest) {
  const Shape in_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6});
  const Shape out_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3});
  auto b = HloComputation::Builder(TestName());

  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, in_shape, "A"));
  HloInstruction* dim_size = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  HloInstruction* reshape =
      b.AddInstruction(HloInstruction::CreateReshape(out_shape, operand));

  module_->AddEntryComputation(b.Build());

  // The same scalar sizes both input dimension 2 and input dimension 3.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 2}));
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 3}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 1), dim_size);
  EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 2), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 3), dim_size);
  EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 4), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, 5), nullptr);
}
596 
// Reshape to a higher rank ([4, 5] -> [1, 4, 5]) with output dimension 0
// marked as the inferred dimension; with input dimension 0 dynamic, the
// inferred output dimension must receive some dynamic size.
TEST_F(DynamicDimensionInferenceTest, ReshapeInferredDimensionTest) {
  const Shape in_shape = ShapeUtil::MakeShape(F32, {4, 5});
  const Shape out_shape = ShapeUtil::MakeShape(F32, {1, 4, 5});
  auto b = HloComputation::Builder(TestName());

  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, in_shape, "A"));
  b.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  HloInstruction* reshape = b.AddInstruction(HloInstruction::CreateReshape(
      out_shape, operand, /*inferred_dimension=*/0));

  module_->AddEntryComputation(b.Build());

  // Parameter 1 is the size of dimension 0 of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
                           DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  EXPECT_NE(inference_->GetDynamicSize(reshape, {}, 0), nullptr);
}
622 
// Reshape [32, 10, 4] -> [320, 4] merges dynamic input dimension 0 with the
// static dimension 1; the merged output dimension 0 must still be dynamic.
TEST_F(DynamicDimensionInferenceTest, ReshapeTestMajorDimension) {
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {32, 10, 4});
  auto output_shape = ShapeUtil::MakeShape(F32, {320, 4});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, input_shape, "A"));

  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  auto* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(output_shape, a_param));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the size of dimension 0 of parameter 0.
  TF_ASSERT_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  // The status must be checked: if inference fails, `inference_` stays null
  // and the dereference below would crash the test binary instead of
  // reporting a failure.
  TF_ASSERT_OK(RunInference());
  EXPECT_NE(inference_->GetDynamicSize(reshape, {}, 0), nullptr);
}
648 
// Tests reshaping an f32[1] operand whose only dimension is dynamic into a
// scalar. Only checks that inference succeeds.
TEST_F(DynamicDimensionInferenceTest, ReshapeIntoScalar) {
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {1});
  auto output_shape = ShapeUtil::MakeShape(F32, {});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, input_shape, "A"));

  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  builder.AddInstruction(HloInstruction::CreateReshape(output_shape, a_param));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the size of dimension 0 of parameter 0.
  TF_ASSERT_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  // ASSERT (not CHECK): a failure is reported for this test instead of
  // aborting the whole test binary, consistent with the other tests.
  TF_ASSERT_OK(RunInference());
}
672 
TEST_F(DynamicDimensionInferenceTest, GatherTest) {
  // The gather's indices operand (parameter 1) has a dynamic leading
  // dimension. The leading dimension of the gather result must report the
  // same size parameter (parameter 2).
  const string hlo_text = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[20,10]{1,0} parameter(0)
  indices = s32[32,20] parameter(1)
  dynamic_size = s32[] parameter(2)
  ROOT gather = s32[32,20,10]{2,1,0} gather(%operand, %indices),
                 offset_dims={2},
                 collapsed_slice_dims={0},
                 start_index_map={0},
                 index_vector_dim=2,
                 slice_sizes={1,10}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));

  // Parameter 2 sizes dimension 0 of parameter 1 (the indices).
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{1, {}, 0}));
  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  auto* entry = module_->entry_computation();
  EXPECT_EQ(inference_->GetDynamicSize(entry->root_instruction(), {}, 0),
            entry->parameter_instruction(2));
}
700 
TEST_F(DynamicDimensionInferenceTest, BroadcastTest) {
  // Broadcasting a dynamic [2] vector into dimension 1 of a [3, 2, 4] result
  // should make only that dimension of the result dynamic.
  auto builder = HloComputation::Builder(TestName());
  auto vector_shape = ShapeUtil::MakeShape(F32, {2});
  auto broadcast_shape = ShapeUtil::MakeShape(F32, {3, 2, 4});

  auto* operand = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, vector_shape, "A"));
  auto* dynamic_size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  // Operand dimension 0 maps onto result dimension 1.
  auto* broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
      broadcast_shape, operand, /*broadcast_dimensions=*/{1}));

  module_->AddEntryComputation(builder.Build());

  // The size parameter governs dimension 0 of the vector operand.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  // Newly created broadcast dimensions (0 and 2) stay static.
  EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 1), dynamic_size);
  EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, 2), nullptr);
}
727 
TEST_F(DynamicDimensionInferenceTest,WhileTest)728 TEST_F(DynamicDimensionInferenceTest, WhileTest) {
729   // Test the ability to trace into while loops.
730   auto builder = HloComputation::Builder(TestName());
731   auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
732   auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
733   auto tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape});
734 
735   // Body:
736   //
737   //   Param
738   //   |  |
739   // GTE1 GTE2
740   //   |  |
741   //    ADD
742   auto body_builder = HloComputation::Builder("body");
743   auto body_param = body_builder.AddInstruction(
744       HloInstruction::CreateParameter(0, tuple_shape, "param"));
745   auto gte_0 = body_builder.AddInstruction(
746       HloInstruction::CreateGetTupleElement(input_shape, body_param, 0));
747   auto gte_1 = body_builder.AddInstruction(
748       HloInstruction::CreateGetTupleElement(input_shape, body_param, 1));
749   auto add = body_builder.AddInstruction(
750       HloInstruction::CreateBinary(input_shape, HloOpcode::kAdd, gte_0, gte_1));
751   body_builder.AddInstruction(HloInstruction::CreateTuple({add, add}));
752 
753   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
754 
755   auto cond_builder = HloComputation::Builder("condition");
756   cond_builder.AddInstruction(
757       HloInstruction::CreateParameter(0, tuple_shape, "param"));
758   cond_builder.AddInstruction(
759       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
760   HloComputation* condition =
761       module_->AddEmbeddedComputation(cond_builder.Build());
762 
763   // Entry:
764   //
765   //  Param
766   //   |
767   //  While
768   auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
769       /*parameter_number=*/0, tuple_shape, "A"));
770   builder.AddInstruction(HloInstruction::CreateParameter(
771       /*parameter_number=*/1, scalar_shape_, "size_param"));
772   builder.AddInstruction(
773       HloInstruction::CreateWhile(tuple_shape, condition, body, a_param));
774 
775   module_->AddEntryComputation(builder.Build());
776 
777   TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
778       DynamicParameterBinding::DynamicParameter{1, {}},
779       DynamicParameterBinding::DynamicDimension{0, {0}, 0}));
780 
781   TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
782       DynamicParameterBinding::DynamicParameter{1, {}},
783       DynamicParameterBinding::DynamicDimension{0, {1}, 0}));
784 
785   TF_ASSERT_OK(RunInference());
786   HloInstruction* while_hlo = nullptr;
787   // The while hlo has been replaced, find the new one.
788   for (HloInstruction* inst : module_->entry_computation()->instructions()) {
789     if (inst->opcode() == HloOpcode::kWhile) {
790       while_hlo = inst;
791     }
792   }
793   ASSERT_NE(while_hlo, nullptr);
794   // The original while shape has 2 parameters. With dynamic size, the tuple
795   // should have 4 elements (We don't deduplicate the arguments).
796   EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 4);
797   HloInstruction* add_inst = nullptr;
798   for (HloInstruction* inst : while_hlo->while_body()->instructions()) {
799     if (inst->opcode() == HloOpcode::kAdd) {
800       add_inst = inst;
801     }
802   }
803   EXPECT_NE(add_inst, nullptr);
804   EXPECT_NE(inference_->GetDynamicSize(add_inst, {}, 0), nullptr);
805   EXPECT_NE(inference_->GetDynamicSize(
806                 module_->entry_computation()->root_instruction(), {0}, 0),
807             nullptr);
808   EXPECT_NE(inference_->GetDynamicSize(
809                 module_->entry_computation()->root_instruction(), {1}, 0),
810             nullptr);
811 }
812 
TEST_F(DynamicDimensionInferenceTest,ConditionalInputTest)813 TEST_F(DynamicDimensionInferenceTest, ConditionalInputTest) {
814   // Test the ability to trace into contional loops.
815   auto builder = HloComputation::Builder(TestName());
816   auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
817   auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
818   // In this test we set inputs to different branches to different shapes.
819   auto tuple_shape_1 = ShapeUtil::MakeTupleShape({input_shape});
820   auto tuple_shape_2 = ShapeUtil::MakeTupleShape({input_shape, input_shape});
821   auto tuple_shape_3 =
822       ShapeUtil::MakeTupleShape({input_shape, input_shape, input_shape});
823 
824   // true branch:
825   //
826   //   Param
827   //   |  |
828   // GTE1 GTE2
829   //   |  |
830   // Tuple(ADD)
831   auto true_builder = HloComputation::Builder("true");
832   {
833     auto true_param = true_builder.AddInstruction(
834         HloInstruction::CreateParameter(0, tuple_shape_2, "param"));
835     auto gte_0 = true_builder.AddInstruction(
836         HloInstruction::CreateGetTupleElement(input_shape, true_param, 0));
837     auto gte_1 = true_builder.AddInstruction(
838         HloInstruction::CreateGetTupleElement(input_shape, true_param, 1));
839     auto add = true_builder.AddInstruction(HloInstruction::CreateBinary(
840         input_shape, HloOpcode::kAdd, gte_0, gte_1));
841     true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
842   }
843   HloComputation* true_branch =
844       module_->AddEmbeddedComputation(true_builder.Build());
845   // false branch:
846   //
847   //      Param
848   //  |     |    |
849   // GTE1  GTE2 GTE3
850   //        |     |
851   //       Tuple(ADD)
852   auto false_builder = HloComputation::Builder("false");
853   {
854     auto false_param = false_builder.AddInstruction(
855         HloInstruction::CreateParameter(0, tuple_shape_3, "param"));
856     auto gte_0 = false_builder.AddInstruction(
857         HloInstruction::CreateGetTupleElement(input_shape, false_param, 1));
858     auto gte_1 = false_builder.AddInstruction(
859         HloInstruction::CreateGetTupleElement(input_shape, false_param, 2));
860     auto add = false_builder.AddInstruction(HloInstruction::CreateBinary(
861         input_shape, HloOpcode::kAdd, gte_0, gte_1));
862     false_builder.AddInstruction(HloInstruction::CreateTuple({add}));
863   }
864   HloComputation* false_branch =
865       module_->AddEmbeddedComputation(false_builder.Build());
866 
867   // Entry:
868   //
869   //  Param(bool) Param2 (tuple_2) Param3(tuple_3)
870   //   |            |                 |
871   //   +---------Condition------------+
872   auto* pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
873       /*parameter_number=*/0, ShapeUtil::MakeScalarShape(PRED), "pred"));
874 
875   auto* tuple_2_param = builder.AddInstruction(HloInstruction::CreateParameter(
876       /*parameter_number=*/1, tuple_shape_2, "tuple_2_param"));
877   auto* tuple_3_param = builder.AddInstruction(HloInstruction::CreateParameter(
878       /*parameter_number=*/2, tuple_shape_3, "tuple_3_param"));
879   builder.AddInstruction(HloInstruction::CreateParameter(
880       /*parameter_number=*/3, scalar_shape_, "size_param"));
881   builder.AddInstruction(HloInstruction::CreateConditional(
882       tuple_shape_1, pred_param, tuple_2_param, true_branch, tuple_3_param,
883       false_branch));
884 
885   module_->AddEntryComputation(builder.Build());
886 
887   TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
888       DynamicParameterBinding::DynamicParameter{3, {}},
889       DynamicParameterBinding::DynamicDimension{1, {0}, 0}));
890   TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
891       DynamicParameterBinding::DynamicParameter{3, {}},
892       DynamicParameterBinding::DynamicDimension{1, {1}, 0}));
893   TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
894       DynamicParameterBinding::DynamicParameter{3, {}},
895       DynamicParameterBinding::DynamicDimension{2, {1}, 0}));
896   TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
897       DynamicParameterBinding::DynamicParameter{3, {}},
898       DynamicParameterBinding::DynamicDimension{2, {2}, 0}));
899 
900   TF_ASSERT_OK(RunInference());
901 
902   HloInstruction* conditional_hlo = nullptr;
903   // The while hlo has been replaced, find the new one.
904   for (HloInstruction* inst : module_->entry_computation()->instructions()) {
905     if (inst->opcode() == HloOpcode::kConditional) {
906       conditional_hlo = inst;
907     }
908   }
909   ASSERT_NE(conditional_hlo, nullptr);
910   // The original conditional shape has 1 parameters. With dynamic size passed
911   // out from the computation, another element is added to the tuple.
912   EXPECT_EQ(conditional_hlo->shape().tuple_shapes_size(), 2);
913   HloInstruction* add_true_branch = nullptr;
914   for (HloInstruction* inst :
915        conditional_hlo->true_computation()->instructions()) {
916     if (inst->opcode() == HloOpcode::kAdd) {
917       add_true_branch = inst;
918     }
919   }
920   EXPECT_NE(add_true_branch, nullptr);
921   EXPECT_NE(inference_->GetDynamicSize(add_true_branch, {}, 0), nullptr);
922 
923   HloInstruction* add_false_branch = nullptr;
924   for (HloInstruction* inst :
925        conditional_hlo->false_computation()->instructions()) {
926     if (inst->opcode() == HloOpcode::kAdd) {
927       add_false_branch = inst;
928     }
929   }
930   EXPECT_NE(add_false_branch, nullptr);
931   EXPECT_NE(inference_->GetDynamicSize(add_false_branch, {}, 0), nullptr);
932 
933   EXPECT_NE(inference_->GetDynamicSize(conditional_hlo, {0}, 0), nullptr);
934 }
935 
TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) {
  // A reduce-window whose window is size 1 / stride 1 along the leading
  // (batch) dimension should forward that dimension's dynamic size.
  auto builder = HloComputation::Builder(TestName());
  auto operand_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  auto result_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});

  Window window;
  // Appends one window dimension with no padding and no dilation.
  auto append_window_dim = [&window](int64 size, int64 stride) {
    WindowDimension* wd = window.add_dimensions();
    wd->set_size(size);
    wd->set_stride(stride);
    wd->set_padding_low(0);
    wd->set_padding_high(0);
    wd->set_window_dilation(1);
    wd->set_base_dilation(1);
  };
  append_window_dim(/*size=*/1, /*stride=*/1);  // Batch dimension: unchanged.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Reduced 4 -> 2.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Reduced 4 -> 2.

  auto* operand = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, operand_shape, "A"));
  auto* dynamic_size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));

  auto* reduce_window =
      builder.AddInstruction(HloInstruction::CreateReduceWindow(
          result_shape, operand, zero, window, GetAdd()));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 sizes the batch dimension of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), dynamic_size);
}
985 
TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) {
  // Select-and-scatter with a pass-through batch window: the dynamic batch
  // dimension of the operand must show up on the result.
  auto builder = HloComputation::Builder(TestName());
  auto operand_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  auto scatter_source_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});

  Window window;
  // Appends one window dimension with no padding and no dilation.
  auto append_window_dim = [&window](int64 size, int64 stride) {
    WindowDimension* wd = window.add_dimensions();
    wd->set_size(size);
    wd->set_stride(stride);
    wd->set_padding_low(0);
    wd->set_padding_high(0);
    wd->set_window_dilation(1);
    wd->set_base_dilation(1);
  };
  append_window_dim(/*size=*/1, /*stride=*/1);  // Batch dimension: unchanged.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Reduced 4 -> 2.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Reduced 4 -> 2.

  auto* operand = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, operand_shape, "A"));
  auto* dynamic_size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));
  auto* scatter_source = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scatter_source_shape, "B"));

  auto zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));

  auto* select_and_scatter =
      builder.AddInstruction(HloInstruction::CreateSelectAndScatter(
          operand_shape, operand, GetGe(), window, scatter_source, zero,
          GetAdd()));

  module_->AddEntryComputation(builder.Build());

  // The same size parameter governs the batch dimension of both the operand
  // (parameter 0) and the scatter source (parameter 2).
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));
  TF_CHECK_OK(binding.Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{2, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(select_and_scatter, {}, 0),
            dynamic_size);
}
1039 
TEST_F(DynamicDimensionInferenceTest, ConcateTest) {
  // Concatenate two operands along dimension 1. Dimension 0 is dynamic on
  // both operands and must stay dynamic on the result.
  auto builder = HloComputation::Builder(TestName());

  auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param_1"));
  auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {5, 8}), "data_param_2"));
  auto dynamic_size = builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_shape_, "size_param"));

  auto* concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
      ShapeUtil::MakeShape(F32, {5, 15}), {lhs, rhs}, /*dimension=*/1));

  module_->AddEntryComputation(builder.Build());

  // Parameter 2 sizes dimension 0 of each data parameter (0, then 1).
  auto& binding = module_->dynamic_parameter_binding();
  for (int data_param_number : {0, 1}) {
    TF_CHECK_OK(binding.Bind(
        DynamicParameterBinding::DynamicParameter{2, {}},
        DynamicParameterBinding::DynamicDimension{data_param_number, {}, 0}));
  }

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(concat, {}, 0), dynamic_size);
}
1067 
TEST_F(DynamicDimensionInferenceTest, SliceTest) {
  // A slice that spans the whole operand keeps the operand's dynamic
  // dimension (here dimension 1).
  auto builder = HloComputation::Builder(TestName());

  auto operand = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  auto dynamic_size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  auto* full_slice = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {5, 7}), operand, /*start_indices=*/{0, 0},
      /*limit_indices=*/{5, 7}, /*strides=*/{1, 1}));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 sizes dimension 1 of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(full_slice, {}, 1), dynamic_size);
}
1089 
TEST_F(DynamicDimensionInferenceTest, DynamicSliceTest) {
  // A dynamic-slice that keeps the full extent (5) of the dynamic dimension 0
  // should keep that dimension dynamic.
  auto builder = HloComputation::Builder(TestName());

  auto operand = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  auto dynamic_size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // One scalar start-index parameter per operand dimension (parameters 2, 3).
  std::vector<HloInstruction*> start_indices;
  for (int dim = 0; dim < 2; ++dim) {
    auto* start_index = builder.AddInstruction(HloInstruction::CreateParameter(
        dim + 2, ShapeUtil::MakeShape(S32, {}), "slice_indices"));
    start_indices.push_back(start_index);
  }

  auto* dynamic_slice = builder.AddInstruction(
      HloInstruction::CreateDynamicSlice(ShapeUtil::MakeShape(F32, {5, 1}),
                                         operand, start_indices,
                                         /*slice_sizes=*/{5, 1}));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 sizes dimension 0 of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(dynamic_slice, {}, 0), dynamic_size);
}
1117 
TEST_F(DynamicDimensionInferenceTest, SortTest) {
  // Sorting along dimension 1 must preserve the dynamic dimension 0.
  auto builder = HloComputation::Builder(TestName());

  auto keys = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  auto dynamic_size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // Trivial comparator: two scalar parameters and a constant `false`.
  auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
  auto comparator_builder = HloComputation::Builder("condition");
  comparator_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "param1"));
  comparator_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_f32, "param2"));
  comparator_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* comparator =
      module_->AddEmbeddedComputation(comparator_builder.Build());

  auto* sort = builder.AddInstruction(HloInstruction::CreateSort(
      ShapeUtil::MakeShape(F32, {5, 7}), /*dimension=*/1, {keys}, comparator,
      /*is_stable=*/false));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 sizes dimension 0 of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(sort, {}, 0), dynamic_size);
}
1149 
TEST_F(DynamicDimensionInferenceTest, MultiValueSortTest) {
  // A two-operand (tuple-shaped) sort must preserve the dynamic dimension 0
  // on both elements of its result.
  auto builder = HloComputation::Builder(TestName());

  auto shape = ShapeUtil::MakeShape(F32, {5, 7});

  auto operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "data_param"));
  auto dynamic_size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // Trivial comparator: four scalar parameters (two per sorted operand) and a
  // constant `false`.
  auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
  auto comparator_builder = HloComputation::Builder("condition");
  int next_param_number = 0;
  for (const char* param_name : {"param1", "param2", "param3", "param4"}) {
    comparator_builder.AddInstruction(HloInstruction::CreateParameter(
        next_param_number++, scalar_f32, param_name));
  }
  comparator_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* comparator =
      module_->AddEmbeddedComputation(comparator_builder.Build());

  auto* sort = builder.AddInstruction(HloInstruction::CreateSort(
      ShapeUtil::MakeTupleShape({shape, shape}), /*dimension=*/1,
      {operand, operand}, comparator, /*is_stable=*/false));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 sizes dimension 0 of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(sort, {0}, 0), dynamic_size);
  EXPECT_EQ(inference_->GetDynamicSize(sort, {1}, 0), dynamic_size);
}
1189 
TEST_F(DynamicDimensionInferenceTest, DynamicSliceSingleElementTest) {
  // Slicing out a single element from a dynamic dimension terminates the
  // dynamic dimension: the [1, 1] result must report no dynamic size.
  auto builder = HloComputation::Builder(TestName());

  auto operand = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // One scalar start-index parameter per operand dimension (parameters 2, 3).
  std::vector<HloInstruction*> start_indices;
  for (int dim = 0; dim < 2; ++dim) {
    auto* start_index = builder.AddInstruction(HloInstruction::CreateParameter(
        dim + 2, ShapeUtil::MakeShape(S32, {}), "slice_indices"));
    start_indices.push_back(start_index);
  }

  auto* element = builder.AddInstruction(
      HloInstruction::CreateDynamicSlice(ShapeUtil::MakeShape(F32, {1, 1}),
                                         operand, start_indices,
                                         /*slice_sizes=*/{1, 1}));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 sizes dimension 0 of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(element, {}, 0), nullptr);
}
1219 
TEST_F(DynamicDimensionInferenceTest, InfersCustomOp) {
  // Inference over a custom call must invoke the user-supplied handler, passing
  // the custom-call instruction and a non-null inference object.
  auto builder = HloComputation::Builder(TestName());

  auto operand = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  builder.AddInstruction(HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {1, 1}), {operand}, "MyCustomOp", ""));

  module_->AddEntryComputation(builder.Build());

  // Dimension 0 of the custom call's operand is dynamic, sized by parameter 1.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(binding.Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  bool invoked = false;
  // Sanity-checks its arguments and records that it ran.
  auto custom_call_handler = [&invoked](HloInstruction* hlo,
                                        DynamicDimensionInference* inference) {
    CHECK(inference != nullptr);
    CHECK(Cast<HloCustomCallInstruction>(hlo) != nullptr);
    invoked = true;
    return Status::OK();
  };
  TF_ASSERT_OK(RunInference(custom_call_handler));

  EXPECT_TRUE(invoked);
}
1250 
TEST_F(DynamicDimensionInferenceTest, DynamicReshapeOp) {
  // The dynamic output dimension of an explicit dynamic-reshape must be sized
  // by the size operand passed to it; its static dimension has no size.
  auto builder = HloComputation::Builder(TestName());
  auto data = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {9}), "data_input"));
  auto six = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(6)));
  // Bounded input of shape [<=9] whose runtime size is 6.
  auto bounded_data =
      builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
          ShapeUtil::MakeShape(F32, {9}, {true}), data, six, 0));
  auto minor_size = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(S32, {}), "size_param"));
  auto three = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(3)));

  // Reshape [<=9] into [3, <=3], with sizes supplied by `three` and
  // `minor_size`.
  auto reshaped = builder.AddInstruction(HloInstruction::CreateDynamicReshape(
      ShapeUtil::MakeShape(F32, {3, 3}, {false, true}), bounded_data,
      {three, minor_size}));

  module_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reshaped, {}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(reshaped, {}, 1), minor_size);
}
1279 
TEST_F(DynamicDimensionInferenceTest, ReshapeOpWithMultipleDynamicDimensions) {
  // A plain reshape that only inserts a unit dimension must keep both existing
  // dynamic dimensions, each with its original size instruction.
  auto builder = HloComputation::Builder(TestName());
  auto raw_input = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {9, 2}), "data_input"));
  auto six = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(6)));
  // Dimension 0 becomes dynamic with size 6.
  auto dim0_dynamic =
      builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
          ShapeUtil::MakeShape(F32, {9, 2}, {true, false}), raw_input, six, 0));
  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
  // Dimension 1 becomes dynamic too, with size 1.
  auto both_dynamic =
      builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
          ShapeUtil::MakeShape(F32, {9, 2}, {true, true}), dim0_dynamic, one,
          1));

  // Reshape [<=9, <=2] into [<=9, 1, <=2].
  auto reshaped = builder.AddInstruction(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(F32, {9, 1, 2}, {true, false, true}),
      both_dynamic));

  module_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reshaped, {}, 0), six);
  EXPECT_EQ(inference_->GetDynamicSize(reshaped, {}, 1), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(reshaped, {}, 2), one);
}
1305 
1306 }  // namespace
1307 }  // namespace xla
1308