• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/value_inference.h"
17 
18 #include <memory>
19 #include <utility>
20 #include <vector>
21 
22 #include "absl/strings/match.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/client/client_library.h"
25 #include "tensorflow/compiler/xla/client/global_data.h"
26 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
27 #include "tensorflow/compiler/xla/client/lib/prng.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/client/xla_computation.h"
30 #include "tensorflow/compiler/xla/layout_util.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/statusor.h"
35 #include "tensorflow/compiler/xla/test.h"
36 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
37 #include "tensorflow/compiler/xla/tests/test_macros.h"
38 #include "tensorflow/compiler/xla/tests/test_utils.h"
39 #include "tensorflow/compiler/xla/xla_data.pb.h"
40 #include "tensorflow/core/lib/core/status_test_util.h"
41 #include "tensorflow/core/platform/status.h"
42 #include "tensorflow/core/platform/statusor.h"
43 #include "tensorflow/core/platform/types.h"
44 
45 namespace xla {
46 namespace {
47 
48 class ValueInferenceTest : public ::testing::Test {
49  public:
TestName() const50   string TestName() const {
51     return ::testing::UnitTest::GetInstance()->current_test_info()->name();
52   }
53 };
54 
55 class DynamismInferenceTest : public ValueInferenceTest {
56  public:
DynamismInferenceTest(se::Platform * platform=nullptr)57   explicit DynamismInferenceTest(se::Platform* platform = nullptr)
58       : platform_(platform) {}
59 
ComputeDynamismLiteral(XlaOp operand,XlaBuilder * builder,Layout * output_layout=nullptr)60   StatusOr<Literal> ComputeDynamismLiteral(XlaOp operand, XlaBuilder* builder,
61                                            Layout* output_layout = nullptr) {
62     ValueInference value_inference(builder);
63     TF_ASSIGN_OR_RETURN(auto literal_slice,
64                         value_inference.AnalyzeIsDynamic(operand));
65     return literal_slice.Clone();
66   }
67 
ComputeDynamismScalar(XlaOp operand,XlaBuilder * builder,ShapeIndex index={})68   StatusOr<bool> ComputeDynamismScalar(XlaOp operand, XlaBuilder* builder,
69                                        ShapeIndex index = {}) {
70     TF_ASSIGN_OR_RETURN(auto literal,
71                         ComputeDynamismLiteral(operand, builder, nullptr));
72     return literal.Get<bool>({}, index);
73   }
74 
75   se::Platform* platform_;
76 };
77 
// A scalar int32 constant must be reported as static.
TEST_F(DynamismInferenceTest, ScalarInt32Literal) {
  XlaBuilder b(TestName());
  auto constant_op = ConstantR0<int32>(&b, 42);

  auto is_dynamic = ComputeDynamismScalar(constant_op, &b);
  ASSERT_TRUE(is_dynamic.ok()) << is_dynamic.status();
  // Constants carry no dynamism.
  EXPECT_EQ(is_dynamic.ValueOrDie(), false);
}
87 
// The elements produced by an iota are expected to be reported as static.
TEST_F(DynamismInferenceTest, Iota) {
  XlaBuilder b(TestName());
  auto iota = Iota(&b, S32, 2);
  // Report an analysis error as a regular test failure rather than crashing
  // inside ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism, ComputeDynamismLiteral(iota, &b));
  // Iota is not dynamic.
  EXPECT_FALSE(dynamism.Get<bool>({0}));
}
96 
// A tuple of (constant, parameter): element 0 is static, element 1 is dynamic.
TEST_F(DynamismInferenceTest, TupleSimple) {
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  auto tuple = Tuple(&b, {c, p});
  // Analyze the tuple once and read both leaves from the same result; an
  // analysis error becomes a test failure instead of a crash.
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism, ComputeDynamismLiteral(tuple, &b));
  EXPECT_FALSE(dynamism.Get<bool>({}, {0}));
  EXPECT_TRUE(dynamism.Get<bool>({}, {1}));
}
106 
// Packing values into a tuple, extracting them with get-tuple-element and
// re-packing them must keep each element's dynamism.
TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  auto tuple = Tuple(&b, {c, p});
  auto gte0 = GetTupleElement(tuple, 0);
  auto gte1 = GetTupleElement(tuple, 1);
  auto tuple_2 = Tuple(&b, {gte0, gte1});
  // One analysis serves both leaves; failures are reported through gtest.
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism,
                          ComputeDynamismLiteral(tuple_2, &b));
  EXPECT_FALSE(dynamism.Get<bool>({}, {0}));
  EXPECT_TRUE(dynamism.Get<bool>({}, {1}));
}
119 
// The parameter feeds both the predicate and an operand of the select; the
// result is expected to be dynamic.
TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");
  auto pred = Eq(c, p);
  auto result = Select(pred, p, c);
  // Fail the test (instead of crashing) if the analysis returns an error.
  TF_ASSERT_OK_AND_ASSIGN(bool is_dynamic,
                          ComputeDynamismScalar(result, &b, {}));
  EXPECT_TRUE(is_dynamic);
}
128 
// A reduction of a parameter feeds both the predicate and an operand of the
// select; the result is expected to be dynamic.
TEST_F(DynamismInferenceTest, ReduceUsedTwice) {
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
  auto zero = ConstantR0<int32>(&b, 0);
  XlaComputation add_s32 = CreateScalarAddComputation(S32, &b);
  auto reduce = Reduce(p, zero, add_s32, {0});
  auto pred = Eq(c, reduce);
  auto result = Select(pred, reduce, c);
  // Fail the test (instead of crashing) if the analysis returns an error.
  TF_ASSERT_OK_AND_ASSIGN(bool is_dynamic,
                          ComputeDynamismScalar(result, &b, {}));
  EXPECT_TRUE(is_dynamic);
}
140 
// Select whose predicate is partly constant and partly a parameter, with
// constant value operands: the test expects each output element to follow the
// dynamism of its selector element.
TEST_F(DynamismInferenceTest, DynamicSelectorWithMixedValues) {
  XlaBuilder b(TestName());
  auto constant_pred = ConstantR1<bool>(&b, {true});
  auto dynamic_pred = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {1}), "p0");
  auto concat = ConcatInDim(&b, {constant_pred, dynamic_pred}, 0);
  auto constant_values = ConstantR1<bool>(&b, {true, true});
  auto result = Select(concat, constant_values, constant_values);
  // Analyze once and fail the test cleanly if the analysis returns an error.
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism,
                          ComputeDynamismLiteral(result, &b));
  // First result is static (selector is constant, both values are constant).
  EXPECT_FALSE(dynamism.Get<bool>({0}));
  // Second result is dynamic (selector is dynamic).
  EXPECT_TRUE(dynamism.Get<bool>({1}));
}
154 
// Concatenating a constant and a parameter, then slicing each element back out
// and reshaping it to a scalar, must keep each element's dynamism.
TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  auto concat = ConcatScalars(&b, {c, p});
  auto slice0 = SliceInDim(concat, 0, 1, 1, 0);
  auto reshape0 = Reshape(slice0, {});
  auto slice1 = SliceInDim(concat, 1, 2, 1, 0);
  auto reshape1 = Reshape(slice1, {});
  auto tuple_2 = Tuple(&b, {reshape0, reshape1});
  // One analysis serves both leaves; failures are reported through gtest.
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism,
                          ComputeDynamismLiteral(tuple_2, &b));
  EXPECT_FALSE(dynamism.Get<bool>({}, {0}));
  EXPECT_TRUE(dynamism.Get<bool>({}, {1}));
}
169 
// A scalar parameter must be reported as dynamic.
TEST_F(DynamismInferenceTest, ParameterIsDynamic) {
  XlaBuilder b(TestName());
  auto param_op = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  auto is_dynamic = ComputeDynamismScalar(param_op, &b);
  ASSERT_TRUE(is_dynamic.ok()) << is_dynamic.status();
  // Parameters are treated as dynamic values.
  EXPECT_EQ(is_dynamic.ValueOrDie(), true);
}
179 
// Negating a value must keep its dynamism: Neg(constant) is static and
// Neg(parameter) is dynamic.
TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) {
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  auto neg0 = Neg(c);
  auto neg1 = Neg(p);
  auto tuple_2 = Tuple(&b, {neg0, neg1});
  // One analysis serves both leaves; failures are reported through gtest.
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism,
                          ComputeDynamismLiteral(tuple_2, &b));
  EXPECT_FALSE(dynamism.Get<bool>({}, {0}));
  EXPECT_TRUE(dynamism.Get<bool>({}, {1}));
}
191 
TEST_F(DynamismInferenceTest, ParameterWithToken) {
  // Test that token shape can be handled in a parameter.
  XlaBuilder b(TestName());
  auto p =
      Parameter(&b, 0,
                ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape(),
                                           ShapeUtil::MakeScalarShape(S32)}),
                "p0");
  // Analyze the parameter once; an analysis error becomes a test failure
  // instead of a crash inside ValueOrDie().
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism, ComputeDynamismLiteral(p, &b));
  // Both the token leaf and the s32 leaf are expected to be dynamic.
  EXPECT_TRUE(dynamism.Get<bool>({}, {0}));
  EXPECT_TRUE(dynamism.Get<bool>({}, {1}));
}
203 
// A binary op is expected to be dynamic if either of its operands is dynamic.
TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) {
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  // Static value + static value = static
  auto add1 = Add(c, c);
  // Dynamic value + static value = dynamic
  auto add2 = Add(p, c);
  auto tuple_2 = Tuple(&b, {add1, add2});
  // One analysis serves both leaves; failures are reported through gtest.
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism,
                          ComputeDynamismLiteral(tuple_2, &b));
  EXPECT_FALSE(dynamism.Get<bool>({}, {0}));
  EXPECT_TRUE(dynamism.Get<bool>({}, {1}));
}
217 
TEST_F(DynamismInferenceTest, GetDimensionSize) {
  XlaBuilder b(TestName());
  // param = Param([<=2, 3])
  // get_dimension_size(param, 0) is dynamic
  // get_dimension_size(param, 1) is static
  auto p =
      Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");

  auto gds0 = GetDimensionSize(p, 0);
  auto gds1 = GetDimensionSize(p, 1);
  auto tuple_2 = Tuple(&b, {gds0, gds1});
  // One analysis serves both leaves; failures are reported through gtest.
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism,
                          ComputeDynamismLiteral(tuple_2, &b));
  EXPECT_TRUE(dynamism.Get<bool>({}, {0}));
  EXPECT_FALSE(dynamism.Get<bool>({}, {1}));
}
232 
// A dynamic-slice whose data and start index are both constants is expected to
// produce a static result.
TEST_F(DynamismInferenceTest, DynamicSliceWithConstantOperands) {
  XlaBuilder b(TestName());

  auto constant = ConstantR1<int32>(&b, {0, 1, 2, 3});
  auto slice_start = ConstantR0(&b, 1);
  auto dynamic_slice = DynamicSlice(constant, {slice_start}, {1});
  // Fail the test (instead of crashing) if the analysis returns an error.
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism,
                          ComputeDynamismLiteral(dynamic_slice, &b));
  EXPECT_FALSE(dynamism.Get<bool>({0}));
}
242 
TEST_F(DynamismInferenceTest, GatherWithCommonParent) {
  XlaBuilder b(TestName());
  // Test the analysis on a gather where first operand and second operand have
  // common parents.
  Shape indices_shape = ShapeUtil::MakeShape(S32, {2});

  auto operand1 = Parameter(&b, 0, indices_shape, "p1");
  auto operand2 = Parameter(&b, 1, indices_shape, "p2");
  // `operand1` is both the gathered data and an input of the indices.
  auto indices = Sub(operand1, operand2);
  GatherDimensionNumbers dim_numbers;
  dim_numbers.add_offset_dims(1);
  dim_numbers.add_start_index_map(0);
  dim_numbers.set_index_vector_dim(1);
  auto gather = Gather(operand1, indices, dim_numbers, {1});
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Fail the test (instead of crashing) if the analysis returns an error.
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism,
                          ComputeDynamismLiteral(gather, &b));
  // Everything derives from parameters, so the result is dynamic.
  EXPECT_TRUE(dynamism.Get<bool>({0, 0}));
}
261 
TEST_F(DynamismInferenceTest, GatherWithConstantParent) {
  XlaBuilder b(TestName());
  // Test the analysis on a gather whose data and indices are both constants.
  auto data_operand = ConstantR1<int32>(&b, {1, 2});
  auto indices = ConstantR1<int32>(&b, {1, 2});
  GatherDimensionNumbers dim_numbers;
  dim_numbers.add_offset_dims(1);
  dim_numbers.add_start_index_map(0);
  dim_numbers.set_index_vector_dim(1);
  auto gather = Gather(data_operand, indices, dim_numbers, {1});
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Fail the test (instead of crashing) if the analysis returns an error.
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism,
                          ComputeDynamismLiteral(gather, &b));
  // Everything is constant, result is also constant.
  EXPECT_FALSE(dynamism.Get<bool>({0, 0}));
}
278 
TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) {
  XlaBuilder b(TestName());
  // Test the analysis on a gather where the constant `operand1` is both the
  // gathered data and an input of the indices.
  auto operand1 = ConstantR1<int32>(&b, {1, 2});
  auto operand2 = ConstantR1<int32>(&b, {1, 2});
  auto indices = Sub(operand1, operand2);
  GatherDimensionNumbers dim_numbers;
  dim_numbers.add_offset_dims(1);
  dim_numbers.add_start_index_map(0);
  dim_numbers.set_index_vector_dim(1);
  auto gather = Gather(operand1, indices, dim_numbers, {1});
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Fail the test (instead of crashing) if the analysis returns an error.
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism,
                          ComputeDynamismLiteral(gather, &b));
  // Everything is constant, result is also constant.
  EXPECT_FALSE(dynamism.Get<bool>({0, 0}));
}
296 
TEST_F(DynamismInferenceTest, InferThroughPad) {
  XlaBuilder b(TestName());
  // Test the analysis on a pad of a constant operand with a dynamic padding
  // value.
  auto operand1 = ConstantR1<int32>(&b, {1, 2});
  auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
  PaddingConfig padding_config;
  padding_config.add_dimensions()->set_edge_padding_high(1);
  // After pad the value is [constant, constant, parameter].
  auto pad = Pad(operand1, parameter, padding_config);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Analyze once and fail the test cleanly if the analysis returns an error.
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism, ComputeDynamismLiteral(pad, &b));
  // The two elements taken from the constant operand are static.
  EXPECT_FALSE(dynamism.Get<bool>({0}));
  EXPECT_FALSE(dynamism.Get<bool>({1}));
  // The padded element comes from the parameter, so it is dynamic.
  EXPECT_TRUE(dynamism.Get<bool>({2}));
}
312 
TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreSame) {
  // The result of following conditional is static.
  // pred = .. # a dynamic value
  // if (pred) {
  //  return (1) # both branches return the same value
  // } else {
  //  return (1)
  // }
  //

  auto s32_shape = ShapeUtil::MakeShape(S32, {});
  XlaBuilder true_builder("true");
  Parameter(&true_builder, 0, s32_shape, "cond_param");
  Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 1)});
  // Build failures are reported through gtest instead of crashing.
  TF_ASSERT_OK_AND_ASSIGN(auto true_computation, true_builder.Build());

  XlaBuilder false_builder("false");
  Parameter(&false_builder, 0, s32_shape, "cond_param");
  Tuple(&false_builder, {ConstantR0<int32>(&false_builder, 1)});
  TF_ASSERT_OK_AND_ASSIGN(auto false_computation, false_builder.Build());

  XlaBuilder b(TestName());
  auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p0");
  auto constant = ConstantR0<int32>(&b, 0);
  auto cond = Conditional(parameter, constant, true_computation, constant,
                          false_computation);
  auto gte = GetTupleElement(cond, 0);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism, ComputeDynamismLiteral(gte, &b));
  // Result is not dynamic.
  EXPECT_FALSE(dynamism.Get<bool>({}));
}
345 
TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreNotSame) {
  // The result of following conditional is dynamic.
  // pred = .. # a dynamic value
  // if (pred) {
  //  return (1) # These two branches return different values.
  // } else {
  //  return (2)
  // }
  //

  auto s32_shape = ShapeUtil::MakeShape(S32, {});
  XlaBuilder true_builder("true");
  Parameter(&true_builder, 0, s32_shape, "cond_param");
  Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 1)});
  // Build failures are reported through gtest instead of crashing.
  TF_ASSERT_OK_AND_ASSIGN(auto true_computation, true_builder.Build());

  XlaBuilder false_builder("false");
  Parameter(&false_builder, 0, s32_shape, "cond_param");
  Tuple(&false_builder, {ConstantR0<int32>(&false_builder, 2)});
  TF_ASSERT_OK_AND_ASSIGN(auto false_computation, false_builder.Build());

  XlaBuilder b(TestName());
  auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p0");
  auto constant = ConstantR0<int32>(&b, 0);
  auto cond = Conditional(parameter, constant, true_computation, constant,
                          false_computation);
  auto gte = GetTupleElement(cond, 0);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism, ComputeDynamismLiteral(gte, &b));
  // Result is dynamic.
  EXPECT_TRUE(dynamism.Get<bool>({}));
}
378 
TEST_F(DynamismInferenceTest, InferThroughConditionalPredIsConstantTrueBranch) {
  // The result of following conditional is static.
  // pred = true
  // if (pred) {
  //  return (0)
  // } else {
  //  return (..the branch operand, forwarded unchanged..)
  // }
  //

  auto s32_shape = ShapeUtil::MakeShape(S32, {});
  XlaBuilder true_builder("true");
  Parameter(&true_builder, 0, s32_shape, "cond_param");
  Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 0)});
  // Build failures are reported through gtest instead of crashing.
  TF_ASSERT_OK_AND_ASSIGN(auto true_computation, true_builder.Build());

  XlaBuilder false_builder("false");
  Tuple(&false_builder,
        {Parameter(&false_builder, 0, s32_shape, "cond_param")});
  TF_ASSERT_OK_AND_ASSIGN(auto false_computation, false_builder.Build());

  XlaBuilder b(TestName());
  auto pred = ConstantR0<bool>(&b, true);
  auto constant = ConstantR0<int32>(&b, 0);
  // Both branches receive the same constant operand.
  auto cond = Conditional(pred, constant, true_computation, constant,
                          false_computation);
  auto gte = GetTupleElement(cond, 0);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism, ComputeDynamismLiteral(gte, &b));
  // Result is not dynamic.
  EXPECT_FALSE(dynamism.Get<bool>({}));
}
411 
TEST_F(DynamismInferenceTest,
       InferThroughConditionalPredIsConstantFalseBranch) {
  // The result of following conditional is dynamic.
  // pred = false
  // if (pred) {
  //  return (0)
  // } else {
  //  return (..dynamic_value...)
  // }
  //

  auto s32_shape = ShapeUtil::MakeShape(S32, {});
  XlaBuilder true_builder("true");
  Parameter(&true_builder, 0, s32_shape, "cond_param");
  Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 0)});
  // Build failures are reported through gtest instead of crashing.
  TF_ASSERT_OK_AND_ASSIGN(auto true_computation, true_builder.Build());

  XlaBuilder false_builder("false");
  Tuple(&false_builder,
        {Parameter(&false_builder, 0, s32_shape, "cond_param")});
  TF_ASSERT_OK_AND_ASSIGN(auto false_computation, false_builder.Build());

  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, s32_shape, "param");
  auto pred = ConstantR0<bool>(&b, false);
  auto constant = ConstantR0<int32>(&b, 0);
  // The false branch forwards its operand, which here is the parameter.
  auto cond =
      Conditional(pred, constant, true_computation, param, false_computation);
  auto gte = GetTupleElement(cond, 0);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism, ComputeDynamismLiteral(gte, &b));
  // Result is dynamic.
  EXPECT_TRUE(dynamism.Get<bool>({}));
}
446 
TEST_F(DynamismInferenceTest,ArgumentForwardingNestedTuple)447 TEST_F(DynamismInferenceTest, ArgumentForwardingNestedTuple) {
448   // The result of following conditional is considered static.
449   // pred = .. dynamic value..
450   //
451   // op = 1
452   // if (pred) {
453   //   if (pred) {
454   //     return op
455   //   } else {
456   //     return op
457   //   }
458   // } else {
459   //   if (pred) {
460   //     return op
461   //   } else {
462   //     return op
463   //   }
464   // }
465   //
466   auto pred_shape = ShapeUtil::MakeShape(PRED, {});
467   auto s32_shape = ShapeUtil::MakeShape(S32, {});
468   auto tuple_shape = ShapeUtil::MakeTupleShape({pred_shape, s32_shape});
469   auto cond_shape = ShapeUtil::MakeTupleShape({s32_shape});
470   XlaBuilder inner_true_builder("inner_true");
471   Parameter(&inner_true_builder, 0, s32_shape, "cond_param");
472   Tuple(&inner_true_builder, {ConstantR0<int32>(&inner_true_builder, 0)});
473   auto inner_true_computation = inner_true_builder.Build().ValueOrDie();
474 
475   XlaBuilder inner_false_builder("inner_false");
476   Tuple(&inner_false_builder,
477         {Parameter(&inner_false_builder, 0, s32_shape, "cond_param")});
478   auto inner_false_computation = inner_false_builder.Build().ValueOrDie();
479 
480   XlaBuilder true_builder("true");
481   {
482     auto param = Parameter(&true_builder, 0, tuple_shape, "param");
483     auto op = GetTupleElement(param, 1);
484     auto pred = GetTupleElement(param, 0);
485     Conditional(pred, op, inner_true_computation, op, inner_false_computation);
486   }
487   auto true_computation = true_builder.Build().ValueOrDie();
488   XlaBuilder false_builder("false");
489   {
490     auto param = Parameter(&false_builder, 0, tuple_shape, "param");
491     auto op = GetTupleElement(param, 1);
492     auto pred = GetTupleElement(param, 0);
493     Conditional(pred, op, inner_true_computation, op, inner_false_computation);
494   }
495   auto false_computation = false_builder.Build().ValueOrDie();
496   XlaBuilder b(TestName());
497   auto constant = ConstantR0<int32>(&b, 0);
498   auto pred = Parameter(&b, 0, pred_shape, "param");
499   auto param = Tuple(&b, {pred, constant});
500   auto cond =
501       Conditional(pred, param, true_computation, param, false_computation);
502   auto gte = GetTupleElement(cond, 0);
503   ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
504   // Result is static.
505   EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
506 }
507 
508 class UpperBoundInferenceTest : public ValueInferenceTest {
509  public:
UpperBoundInferenceTest(se::Platform * platform=nullptr)510   explicit UpperBoundInferenceTest(se::Platform* platform = nullptr)
511       : platform_(platform) {}
512 
ComputeUpperBoundLiteral(XlaOp operand,XlaBuilder * builder,Layout * output_layout=nullptr)513   StatusOr<OptionalLiteral> ComputeUpperBoundLiteral(
514       XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) {
515     ValueInference value_inference(builder);
516     TF_ASSIGN_OR_RETURN(auto literal,
517                         value_inference.AnalyzeConstant(
518                             operand, ValueInferenceMode::kUpperBound));
519     return literal;
520   }
521 
522   se::Platform* platform_;
523 };
524 
// The upper bound of get-dimension-size is the dimension's declared size, for
// both the dynamic dimension (0) and the static dimension (1).
TEST_F(UpperBoundInferenceTest, GetDimensionSize) {
  XlaBuilder b(TestName());
  auto p =
      Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");

  auto gds0 = GetDimensionSize(p, 0);
  auto gds1 = GetDimensionSize(p, 1);
  auto tuple_2 = Tuple(&b, {gds0, gds1});
  // Analyze the tuple once and read both leaves from the same result; an
  // analysis error becomes a test failure instead of a crash.
  TF_ASSERT_OK_AND_ASSIGN(OptionalLiteral bounds,
                          ComputeUpperBoundLiteral(tuple_2, &b));
  EXPECT_EQ(bounds.Get<int32>({}, {0}), 2);
  EXPECT_EQ(bounds.Get<int32>({}, {1}), 3);
}
540 
TEST_F(UpperBoundInferenceTest, GetDimensionSizeSub) {
  XlaBuilder b(TestName());
  auto p =
      Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");

  // The range of the first dimension is [0, 2]
  auto gds0 = GetDimensionSize(p, 0);
  // The range of the second dimension is [3, 3]
  auto gds1 = GetDimensionSize(p, 1);
  // Upper bound of `second_dimension - first_dimension` is 3 - 0 = 3
  auto sub = Sub(gds1, gds0);
  // Fail the test (instead of crashing) if the analysis returns an error.
  TF_ASSERT_OK_AND_ASSIGN(OptionalLiteral bound,
                          ComputeUpperBoundLiteral(sub, &b));
  EXPECT_EQ(bound.Get<int32>({}), 3);
}
554 
TEST_F(UpperBoundInferenceTest, GetDimensionSizeDiv) {
  XlaBuilder b(TestName());
  auto p =
      Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "p0");
  // The range of the first dimension is [0, 2]
  auto gds0 = GetDimensionSize(p, 0);
  // The range of the second dimension is [3, 3]
  auto gds1 = GetDimensionSize(p, 1);
  // Upper bound of `second_dimension / first_dimension` is 3 / 1 = 3. Notice we
  // don't use 0 as the lower bound as it would create divide-by-zero error.
  auto div = Div(gds1, gds0);
  // Fail the test (instead of crashing) if the analysis returns an error.
  TF_ASSERT_OK_AND_ASSIGN(OptionalLiteral bound,
                          ComputeUpperBoundLiteral(div, &b));
  EXPECT_EQ(bound.Get<int32>({}), 3);
}
568 
TEST_F(UpperBoundInferenceTest, SumSubtract) {
  // If x = a, y = b - a
  // upperbound(x + y) should be upperbound(b)
  XlaBuilder b(TestName());
  auto p =
      Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, true}), "p0");
  // The range of the first dimension is [0, 2]
  auto gds0 = GetDimensionSize(p, 0);
  // The range of the second dimension is [0, 3]
  auto gds1 = GetDimensionSize(p, 1);
  auto sub = Sub(gds1, gds0);
  auto add = Add(sub, gds0);
  // Fail the test (instead of crashing) if an analysis returns an error.
  TF_ASSERT_OK_AND_ASSIGN(OptionalLiteral add_bound,
                          ComputeUpperBoundLiteral(add, &b));
  EXPECT_EQ(add_bound.Get<int32>({}), 3);
  auto add2 = Add(gds1, gds0);
  // upperbound(gds1 - gds0 + gds1 + gds0) ==> upperbound(2 * gds1)
  auto add3 = Add(sub, add2);
  TF_ASSERT_OK_AND_ASSIGN(OptionalLiteral add3_bound,
                          ComputeUpperBoundLiteral(add3, &b));
  EXPECT_EQ(add3_bound.Get<int32>({}), 6);
}
587 
TEST_F(UpperBoundInferenceTest, SumSubtractWithDataShuffling) {
  // Similar to the test above, but with some data shuffling ops in it
  // (broadcast, slice, reshape, identity convert, etc).
  XlaBuilder b(TestName());
  auto p =
      Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, true}), "p0");
  // The range of the first dimension is [0, 2]
  auto gds0 = GetDimensionSize(p, 0);
  // The range of the second dimension is [0, 3]
  auto gds1 = GetDimensionSize(p, 1);
  auto broadcast = Broadcast(gds0, {1, 10});
  auto convert = ConvertElementType(broadcast, S32);  // Identity convert.
  auto slice = SliceInDim(convert, /*start_index=*/0, /*limit_index=*/1,
                          /*stride=*/1, /*dimno=*/1);
  // Replace gds0 with the same value routed through the shuffling ops.
  gds0 = Reshape(slice, {});
  auto sub = Sub(gds1, gds0);
  auto add = Add(sub, gds0);
  // Fail the test (instead of crashing) if an analysis returns an error.
  TF_ASSERT_OK_AND_ASSIGN(OptionalLiteral add_bound,
                          ComputeUpperBoundLiteral(add, &b));
  EXPECT_EQ(add_bound.Get<int32>({}), 3);
  auto add2 = Add(gds1, gds0);
  // upperbound(gds1 - gds0 + gds1 + gds0) ==> upperbound(2 * gds1)
  auto add3 = Add(sub, add2);
  TF_ASSERT_OK_AND_ASSIGN(OptionalLiteral add3_bound,
                          ComputeUpperBoundLiteral(add3, &b));
  EXPECT_EQ(add3_bound.Get<int32>({}), 6);
}
611 
TEST_F(UpperBoundInferenceTest, SumSubtractEquivalentGetDimensionSize) {
  XlaBuilder b(TestName());
  auto p =
      Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, true}), "p0");
  // The range of the first dimension is [0, 2]
  auto gds0 = GetDimensionSize(p, 0);
  // The range of the second dimension is [0, 3]
  auto gds1 = GetDimensionSize(p, 1);
  // gds2 is equivalent to gds0
  auto gds2 = GetDimensionSize(p, 0);
  auto sub = Sub(gds1, gds2);
  auto add = Add(sub, gds0);
  // Fail the test (instead of crashing) if the analysis returns an error.
  TF_ASSERT_OK_AND_ASSIGN(OptionalLiteral bound,
                          ComputeUpperBoundLiteral(add, &b));
  // upperbound(gds0 + gds1 - gds2) is equal to upperbound(gds1) if gds0 ==
  // gds2.
  EXPECT_EQ(bound.Get<int32>({}), 3);
}
628 
TEST_F(UpperBoundInferenceTest, ParamCantInferBound) {
  // We can infer a parameter's dimension's bound, but not the parameter value's
  // bound.
  XlaBuilder b(TestName());
  auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}, {true}), "p0");
  auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}, {}), "p1");
  auto gds = GetDimensionSize(p0, 0);
  // Dividing by a parameter value makes the bound of the result unknown.
  auto div = Div(gds, p1);
  // Fail the test (instead of crashing) if the analysis returns an error.
  TF_ASSERT_OK_AND_ASSIGN(OptionalLiteral bound,
                          ComputeUpperBoundLiteral(div, &b));
  EXPECT_FALSE(bound.Get<int32>({}).has_value());
}
642 
// Key/value sort with parameter keys and iota values: every element of the
// sorted values is expected to have an upper bound equal to the largest iota
// value.
TEST_F(UpperBoundInferenceTest, KeyValueSort) {
  // Comparator over (key, key, value, value) scalars that only compares keys.
  XlaBuilder comparator_b("comparator");
  auto p0 = Parameter(&comparator_b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
  auto p1 = Parameter(&comparator_b, 1, ShapeUtil::MakeShape(S32, {}), "p1");
  Parameter(&comparator_b, 2, ShapeUtil::MakeShape(S32, {}), "p2");
  Parameter(&comparator_b, 3, ShapeUtil::MakeShape(S32, {}), "p3");
  Compare(p0, p1, ComparisonDirection::kGe);
  TF_ASSERT_OK_AND_ASSIGN(auto comparator, comparator_b.Build());

  int64_t elem_count = 17;
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {elem_count}), "p0");
  auto iota = Iota(&b, S32, elem_count);
  auto sort = Sort({param, iota}, comparator);
  auto gte = GetTupleElement(sort, 1);

  // The analysis does not depend on the loop index, so run it once up front;
  // an analysis error becomes a test failure instead of a crash.
  TF_ASSERT_OK_AND_ASSIGN(OptionalLiteral bounds,
                          ComputeUpperBoundLiteral(gte, &b));
  for (int64_t i = 0; i < elem_count; ++i) {
    const auto element_bound = bounds.Get<int32>({i});
    // We can infer the bound of sort. Assert (not expect) so that value()
    // below is never called on an empty optional.
    ASSERT_TRUE(element_bound.has_value());
    // The bound of the sort result is the max value in the input.
    EXPECT_EQ(element_bound.value(), elem_count - 1);
  }
}
668 
669 class ConstValueInferenceTest : public ValueInferenceTest {
670  public:
ConstValueInferenceTest(se::Platform * platform=nullptr)671   explicit ConstValueInferenceTest(se::Platform* platform = nullptr)
672       : platform_(platform) {}
673 
ComputeConstantValueLiteral(XlaOp operand,XlaBuilder * builder,Layout * output_layout=nullptr)674   StatusOr<OptionalLiteral> ComputeConstantValueLiteral(
675       XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) {
676     ValueInference value_inference(builder);
677     TF_ASSIGN_OR_RETURN(auto literal, value_inference.AnalyzeConstant(
678                                           operand, ValueInferenceMode::kValue));
679     return literal;
680   }
681 
682   se::Platform* platform_;
683 };
684 
// A constant wrapped in a "SetBound" custom call is expected to still be
// inferred as that constant value.
TEST_F(ConstValueInferenceTest, ConstValuePassThroughSetBound) {
  XlaBuilder b(TestName());
  auto constant = ConstantR0<int32>(&b, 32);
  Shape shape = ShapeUtil::MakeShape(S32, {});
  // The custom call carries a (bound, dynamism) tuple literal.
  xla::Literal dynamism = xla::LiteralUtil::CreateR0<bool>(false);
  xla::Literal bound = xla::LiteralUtil::CreateR0<int32>(32);
  xla::Literal tuple =
      xla::LiteralUtil::MakeTupleOwned(std::move(bound), std::move(dynamism));
  auto set_bound =
      CustomCall(&b, "SetBound", {constant}, shape, "", false, {}, &tuple);
  // Fail the test (instead of crashing) if the analysis returns an error.
  TF_ASSERT_OK_AND_ASSIGN(OptionalLiteral value,
                          ComputeConstantValueLiteral(set_bound, &b));
  const auto result = value.Get<int32>({});
  // Assert (not expect) so that value() below is never called on an empty
  // optional.
  ASSERT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), 32);
}
700 
701 }  // namespace
702 }  // namespace xla
703