1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/value_inference.h"
17
18 #include <memory>
19 #include <utility>
20 #include <vector>
21
22 #include "absl/strings/match.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/client/client_library.h"
25 #include "tensorflow/compiler/xla/client/global_data.h"
26 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
27 #include "tensorflow/compiler/xla/client/lib/prng.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/client/xla_computation.h"
30 #include "tensorflow/compiler/xla/layout_util.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/statusor.h"
35 #include "tensorflow/compiler/xla/test.h"
36 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
37 #include "tensorflow/compiler/xla/tests/test_macros.h"
38 #include "tensorflow/compiler/xla/tests/test_utils.h"
39 #include "tensorflow/compiler/xla/xla_data.pb.h"
40 #include "tensorflow/core/lib/core/status_test_util.h"
41 #include "tensorflow/core/platform/status.h"
42 #include "tensorflow/core/platform/statusor.h"
43 #include "tensorflow/core/platform/types.h"
44
45 namespace xla {
46 namespace {
47
48 class ValueInferenceTest : public ::testing::Test {
49 public:
TestName() const50 string TestName() const {
51 return ::testing::UnitTest::GetInstance()->current_test_info()->name();
52 }
53 };
54
55 class DynamismInferenceTest : public ValueInferenceTest {
56 public:
DynamismInferenceTest(se::Platform * platform=nullptr)57 explicit DynamismInferenceTest(se::Platform* platform = nullptr)
58 : platform_(platform) {}
59
ComputeDynamismLiteral(XlaOp operand,XlaBuilder * builder,Layout * output_layout=nullptr)60 StatusOr<Literal> ComputeDynamismLiteral(XlaOp operand, XlaBuilder* builder,
61 Layout* output_layout = nullptr) {
62 ValueInference value_inference(builder);
63 TF_ASSIGN_OR_RETURN(auto literal_slice,
64 value_inference.AnalyzeIsDynamic(operand));
65 return literal_slice.Clone();
66 }
67
ComputeDynamismScalar(XlaOp operand,XlaBuilder * builder,ShapeIndex index={})68 StatusOr<bool> ComputeDynamismScalar(XlaOp operand, XlaBuilder* builder,
69 ShapeIndex index = {}) {
70 TF_ASSIGN_OR_RETURN(auto literal,
71 ComputeDynamismLiteral(operand, builder, nullptr));
72 return literal.Get<bool>({}, index);
73 }
74
75 se::Platform* platform_;
76 };
77
TEST_F(DynamismInferenceTest, ScalarInt32Literal) {
  XlaBuilder builder(TestName());
  XlaOp forty_two = ConstantR0<int32>(&builder, 42);

  StatusOr<bool> is_dynamic = ComputeDynamismScalar(forty_two, &builder);
  ASSERT_TRUE(is_dynamic.ok()) << is_dynamic.status();
  // Constants are never dynamic.
  EXPECT_EQ(is_dynamic.ValueOrDie(), false);
}
87
TEST_F(DynamismInferenceTest, Iota) {
  // Every element produced by iota is consistently static.
  XlaBuilder b(TestName());
  auto computation = Iota(&b, S32, 2);
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism,
                          ComputeDynamismLiteral(computation, &b));
  // Check every element of the 2-element iota, not just the first one.
  EXPECT_FALSE(dynamism.Get<bool>({0}));
  EXPECT_FALSE(dynamism.Get<bool>({1}));
}
96
TEST_F(DynamismInferenceTest, TupleSimple) {
  XlaBuilder builder(TestName());
  XlaOp static_elem = ConstantR0<int32>(&builder, 42);
  XlaOp dynamic_elem =
      Parameter(&builder, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  // Each tuple element keeps the dynamism of the value it was built from.
  XlaOp packed = Tuple(&builder, {static_elem, dynamic_elem});
  EXPECT_EQ(ComputeDynamismScalar(packed, &builder, {0}).ValueOrDie(), false);
  EXPECT_EQ(ComputeDynamismScalar(packed, &builder, {1}).ValueOrDie(), true);
}
106
TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) {
  XlaBuilder builder(TestName());
  XlaOp constant = ConstantR0<int32>(&builder, 42);
  XlaOp param = Parameter(&builder, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  // Unpack a (constant, param) tuple with get-tuple-element and repack it; the
  // per-element dynamism is expected to survive the round trip.
  XlaOp packed = Tuple(&builder, {constant, param});
  XlaOp elem0 = GetTupleElement(packed, 0);
  XlaOp elem1 = GetTupleElement(packed, 1);
  XlaOp repacked = Tuple(&builder, {elem0, elem1});
  EXPECT_EQ(ComputeDynamismScalar(repacked, &builder, {0}).ValueOrDie(), false);
  EXPECT_EQ(ComputeDynamismScalar(repacked, &builder, {1}).ValueOrDie(), true);
}
119
TEST_F(DynamismInferenceTest, PredValueUsedTwice) {
  XlaBuilder builder(TestName());
  XlaOp constant = ConstantR0<int32>(&builder, 42);
  XlaOp param = Parameter(&builder, 0, ShapeUtil::MakeScalarShape(S32), "p0");
  // The dynamic parameter feeds both the predicate and one select branch.
  XlaOp is_equal = Eq(constant, param);
  XlaOp selected = Select(is_equal, param, constant);
  EXPECT_EQ(ComputeDynamismScalar(selected, &builder, {}).ValueOrDie(), true);
}
128
TEST_F(DynamismInferenceTest, ReduceUsedTwice) {
  XlaBuilder builder(TestName());
  XlaOp constant = ConstantR0<int32>(&builder, 42);
  XlaOp param = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {2}), "p0");
  XlaOp init = ConstantR0<int32>(&builder, 0);
  XlaComputation adder = CreateScalarAddComputation(S32, &builder);
  // Summing the dynamic vector gives a value that is consumed twice: once by
  // the predicate and once as a select branch.
  XlaOp sum = Reduce(param, init, adder, {0});
  XlaOp is_equal = Eq(constant, sum);
  XlaOp selected = Select(is_equal, sum, constant);
  EXPECT_EQ(ComputeDynamismScalar(selected, &builder, {}).ValueOrDie(), true);
}
140
TEST_F(DynamismInferenceTest, DynamicSelectorWithMixedValues) {
  XlaBuilder b(TestName());
  auto constant_pred = ConstantR1<bool>(&b, {true});
  auto dynamic_pred = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {1}), "p0");
  // The selector is [constant, dynamic].
  auto concat = ConcatInDim(&b, {constant_pred, dynamic_pred}, 0);
  auto constant_values = ConstantR1<bool>(&b, {true, true});
  auto result = Select(concat, constant_values, constant_values);
  // Run the analysis once and inspect both elements of its result.
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism,
                          ComputeDynamismLiteral(result, &b));
  // First result is static (selector is constant, both values are constant).
  EXPECT_FALSE(dynamism.Get<bool>({0}));
  // Second result is dynamic (selector is dynamic).
  EXPECT_TRUE(dynamism.Get<bool>({1}));
}
154
TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) {
  XlaBuilder builder(TestName());
  XlaOp constant = ConstantR0<int32>(&builder, 42);
  XlaOp param = Parameter(&builder, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  // Build the vector [constant, param], then slice each element back out and
  // reshape it to a scalar.
  XlaOp vec = ConcatScalars(&builder, {constant, param});
  XlaOp first_slice = SliceInDim(vec, 0, 1, 1, 0);
  XlaOp first = Reshape(first_slice, {});
  XlaOp second_slice = SliceInDim(vec, 1, 2, 1, 0);
  XlaOp second = Reshape(second_slice, {});
  XlaOp packed = Tuple(&builder, {first, second});
  EXPECT_EQ(ComputeDynamismScalar(packed, &builder, {0}).ValueOrDie(), false);
  EXPECT_EQ(ComputeDynamismScalar(packed, &builder, {1}).ValueOrDie(), true);
}
169
TEST_F(DynamismInferenceTest, ParameterIsDynamic) {
  XlaBuilder builder(TestName());
  XlaOp param = Parameter(&builder, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  StatusOr<bool> is_dynamic = ComputeDynamismScalar(param, &builder);
  ASSERT_TRUE(is_dynamic.ok()) << is_dynamic.status();
  // Parameters are treated as dynamic.
  EXPECT_EQ(is_dynamic.ValueOrDie(), true);
}
179
TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) {
  XlaBuilder builder(TestName());
  XlaOp constant = ConstantR0<int32>(&builder, 42);
  XlaOp param = Parameter(&builder, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  // Negation neither adds nor removes dynamism.
  XlaOp negated_constant = Neg(constant);
  XlaOp negated_param = Neg(param);
  XlaOp results = Tuple(&builder, {negated_constant, negated_param});
  EXPECT_EQ(ComputeDynamismScalar(results, &builder, {0}).ValueOrDie(), false);
  EXPECT_EQ(ComputeDynamismScalar(results, &builder, {1}).ValueOrDie(), true);
}
191
TEST_F(DynamismInferenceTest, ParameterWithToken) {
  // A parameter whose tuple shape contains a token must still be analyzable;
  // both of its elements are reported as dynamic.
  XlaBuilder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeTokenShape(), ShapeUtil::MakeScalarShape(S32)});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  EXPECT_EQ(ComputeDynamismScalar(param, &builder, {0}).ValueOrDie(), true);
  EXPECT_EQ(ComputeDynamismScalar(param, &builder, {1}).ValueOrDie(), true);
}
203
TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) {
  // The dynamism of a binary op's result is the OR of its operands' dynamism.
  XlaBuilder b(TestName());
  auto c = ConstantR0<int32>(&b, 42);
  auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0");

  // Static value + static value = static
  auto add1 = Add(c, c);
  // Dynamic value + static value = dynamic
  auto add2 = Add(p, c);
  auto tuple_2 = Tuple(&b, {add1, add2});
  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {0}).ValueOrDie(), false);
  EXPECT_EQ(ComputeDynamismScalar(tuple_2, &b, {1}).ValueOrDie(), true);
}
217
TEST_F(DynamismInferenceTest, GetDimensionSize) {
  // For a parameter of shape [<=2, 3], the size of dimension 0 is dynamic and
  // the size of dimension 1 is static.
  XlaBuilder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeShape(S32, {2, 3}, {true, false});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");

  XlaOp dim0_size = GetDimensionSize(param, 0);
  XlaOp dim1_size = GetDimensionSize(param, 1);
  XlaOp sizes = Tuple(&builder, {dim0_size, dim1_size});
  EXPECT_EQ(ComputeDynamismScalar(sizes, &builder, {0}).ValueOrDie(), true);
  EXPECT_EQ(ComputeDynamismScalar(sizes, &builder, {1}).ValueOrDie(), false);
}
232
TEST_F(DynamismInferenceTest, DynamicSliceWithConstantOperands) {
  // Slicing a constant array at a constant offset yields a static value.
  XlaBuilder builder(TestName());
  XlaOp data = ConstantR1<int32>(&builder, {0, 1, 2, 3});
  XlaOp start = ConstantR0(&builder, 1);
  XlaOp sliced = DynamicSlice(data, {start}, {1});
  StatusOr<Literal> dynamism = ComputeDynamismLiteral(sliced, &builder);
  EXPECT_FALSE(dynamism.ValueOrDie().Get<bool>({0}));
}
242
TEST_F(DynamismInferenceTest, GatherWithCommonParent) {
  // Analyze a gather whose data operand is also an ancestor of its indices.
  XlaBuilder builder(TestName());
  const Shape vec_shape = ShapeUtil::MakeShape(S32, {2});
  XlaOp lhs = Parameter(&builder, 0, vec_shape, "p1");
  XlaOp rhs = Parameter(&builder, 1, vec_shape, "p2");
  XlaOp gather_indices = Sub(lhs, rhs);

  GatherDimensionNumbers gather_dims;
  gather_dims.add_offset_dims(1);
  gather_dims.add_start_index_map(0);
  gather_dims.set_index_vector_dim(1);
  XlaOp gathered = Gather(lhs, gather_indices, gather_dims, {1});
  ASSERT_TRUE(builder.first_error().ok())
      << builder.first_error().error_message();

  // Both operands derive from parameters, so the gathered value is dynamic.
  StatusOr<Literal> dynamism = ComputeDynamismLiteral(gathered, &builder);
  EXPECT_TRUE(dynamism.ValueOrDie().Get<bool>({0, 0}));
}
261
TEST_F(DynamismInferenceTest, GatherWithConstantParent) {
  XlaBuilder b(TestName());
  // Test the analysis on a gather whose data operand and indices are both
  // constants.
  auto data_operand = ConstantR1<int32>(&b, {1, 2});
  auto indices = ConstantR1<int32>(&b, {1, 2});
  GatherDimensionNumbers dim_numbers;
  dim_numbers.add_offset_dims(1);
  dim_numbers.add_start_index_map(0);
  dim_numbers.set_index_vector_dim(1);
  auto gather = Gather(data_operand, indices, dim_numbers, {1});
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Everything is constant, so the result is also constant.
  EXPECT_FALSE(
      ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
}
278
TEST_F(DynamismInferenceTest, GatherWithSharedConstantParent) {
  XlaBuilder b(TestName());
  // Test the analysis on a gather whose indices are computed from the same
  // constant that serves as its data operand.
  auto operand1 = ConstantR1<int32>(&b, {1, 2});
  auto operand2 = ConstantR1<int32>(&b, {1, 2});
  auto indices = Sub(operand1, operand2);
  GatherDimensionNumbers dim_numbers;
  dim_numbers.add_offset_dims(1);
  dim_numbers.add_start_index_map(0);
  dim_numbers.set_index_vector_dim(1);
  auto gather = Gather(operand1, indices, dim_numbers, {1});
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Everything is constant, so the result is also constant.
  EXPECT_FALSE(
      ComputeDynamismLiteral(gather, &b).ValueOrDie().Get<bool>({0, 0}));
}
296
TEST_F(DynamismInferenceTest, InferThroughPad) {
  XlaBuilder b(TestName());
  // Test the analysis on a pad whose operand is constant but whose padding
  // value is a dynamic parameter.
  auto operand1 = ConstantR1<int32>(&b, {1, 2});
  auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
  PaddingConfig padding_config;
  padding_config.add_dimensions()->set_edge_padding_high(1);
  // After pad the value is [constant, constant, parameter].
  auto pad = Pad(operand1, parameter, padding_config);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  TF_ASSERT_OK_AND_ASSIGN(Literal dynamism, ComputeDynamismLiteral(pad, &b));
  // The two elements copied from the constant operand are static...
  EXPECT_FALSE(dynamism.Get<bool>({0}));
  EXPECT_FALSE(dynamism.Get<bool>({1}));
  // ...while the padded element comes from the dynamic parameter.
  EXPECT_TRUE(dynamism.Get<bool>({2}));
}
312
TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreSame) {
  // The result of the following conditional is static:
  // pred = .. # a dynamic value
  // if (pred) {
  //   return (1) # both branches return the same value
  // } else {
  //   return (1)
  // }
  //

  auto s32_shape = ShapeUtil::MakeShape(S32, {});
  XlaBuilder true_builder("true");
  Parameter(&true_builder, 0, s32_shape, "cond_param");
  Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 1)});
  auto true_computation = true_builder.Build().ValueOrDie();

  XlaBuilder false_builder("false");
  Parameter(&false_builder, 0, s32_shape, "cond_param");
  Tuple(&false_builder, {ConstantR0<int32>(&false_builder, 1)});
  auto false_computation = false_builder.Build().ValueOrDie();

  XlaBuilder b(TestName());
  auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p0");
  auto constant = ConstantR0<int32>(&b, 0);
  auto cond = Conditional(parameter, constant, true_computation, constant,
                          false_computation);
  auto gte = GetTupleElement(cond, 0);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Result is not dynamic.
  EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
}
345
TEST_F(DynamismInferenceTest, InferThroughConditionalBranchesAreNotSame) {
  // The result of the following conditional is dynamic:
  // pred = .. # a dynamic value
  // if (pred) {
  //   return (1) # These two branches return different values.
  // } else {
  //   return (2)
  // }
  //

  auto s32_shape = ShapeUtil::MakeShape(S32, {});
  XlaBuilder true_builder("true");
  Parameter(&true_builder, 0, s32_shape, "cond_param");
  Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 1)});
  auto true_computation = true_builder.Build().ValueOrDie();

  XlaBuilder false_builder("false");
  Parameter(&false_builder, 0, s32_shape, "cond_param");
  Tuple(&false_builder, {ConstantR0<int32>(&false_builder, 2)});
  auto false_computation = false_builder.Build().ValueOrDie();

  XlaBuilder b(TestName());
  auto parameter = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "p0");
  auto constant = ConstantR0<int32>(&b, 0);
  auto cond = Conditional(parameter, constant, true_computation, constant,
                          false_computation);
  auto gte = GetTupleElement(cond, 0);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Result is dynamic.
  EXPECT_TRUE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
}
378
TEST_F(DynamismInferenceTest, InferThroughConditionalPredIsConstantTrueBranch) {
  // The result of the following conditional is static, because the constant
  // predicate always selects the true branch:
  // pred = true
  // if (pred) {
  //   return (0)
  // } else {
  //   return (operand) # forwards the branch operand
  // }
  //

  auto s32_shape = ShapeUtil::MakeShape(S32, {});
  XlaBuilder true_builder("true");
  Parameter(&true_builder, 0, s32_shape, "cond_param");
  Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 0)});
  auto true_computation = true_builder.Build().ValueOrDie();

  XlaBuilder false_builder("false");
  Tuple(&false_builder,
        {Parameter(&false_builder, 0, s32_shape, "cond_param")});
  auto false_computation = false_builder.Build().ValueOrDie();

  XlaBuilder b(TestName());
  auto pred = ConstantR0<bool>(&b, true);
  auto constant = ConstantR0<int32>(&b, 0);
  auto cond = Conditional(pred, constant, true_computation, constant,
                          false_computation);
  auto gte = GetTupleElement(cond, 0);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Result is not dynamic.
  EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
}
411
TEST_F(DynamismInferenceTest,
       InferThroughConditionalPredIsConstantFalseBranch) {
  // The result of the following conditional is dynamic, because the constant
  // predicate always selects the false branch, which forwards a dynamic
  // parameter:
  // pred = false
  // if (pred) {
  //   return (0)
  // } else {
  //   return (param) # param is a parameter of the outer computation
  // }
  //

  auto s32_shape = ShapeUtil::MakeShape(S32, {});
  XlaBuilder true_builder("true");
  Parameter(&true_builder, 0, s32_shape, "cond_param");
  Tuple(&true_builder, {ConstantR0<int32>(&true_builder, 0)});
  auto true_computation = true_builder.Build().ValueOrDie();

  XlaBuilder false_builder("false");
  Tuple(&false_builder,
        {Parameter(&false_builder, 0, s32_shape, "cond_param")});
  auto false_computation = false_builder.Build().ValueOrDie();

  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, s32_shape, "param");
  auto pred = ConstantR0<bool>(&b, false);
  auto constant = ConstantR0<int32>(&b, 0);
  auto cond =
      Conditional(pred, constant, true_computation, param, false_computation);
  auto gte = GetTupleElement(cond, 0);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Result is dynamic.
  EXPECT_TRUE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
}
446
TEST_F(DynamismInferenceTest, ArgumentForwardingNestedTuple) {
  // The result of the following conditional is considered static:
  // pred = .. dynamic value..
  //
  // op = 0
  // if (pred) {
  //   if (pred) {
  //     return 0   # inner true branch returns a constant
  //   } else {
  //     return op  # inner false branch forwards its operand
  //   }
  // } else {
  //   if (pred) {
  //     return 0
  //   } else {
  //     return op
  //   }
  // }
  //
  // Every path yields the value 0, which is why the result is expected to be
  // static despite the dynamic predicate.
  auto pred_shape = ShapeUtil::MakeShape(PRED, {});
  auto s32_shape = ShapeUtil::MakeShape(S32, {});
  auto tuple_shape = ShapeUtil::MakeTupleShape({pred_shape, s32_shape});
  XlaBuilder inner_true_builder("inner_true");
  Parameter(&inner_true_builder, 0, s32_shape, "cond_param");
  Tuple(&inner_true_builder, {ConstantR0<int32>(&inner_true_builder, 0)});
  auto inner_true_computation = inner_true_builder.Build().ValueOrDie();

  XlaBuilder inner_false_builder("inner_false");
  Tuple(&inner_false_builder,
        {Parameter(&inner_false_builder, 0, s32_shape, "cond_param")});
  auto inner_false_computation = inner_false_builder.Build().ValueOrDie();

  // Both outer branches unpack (pred, op) from their tuple parameter and feed
  // them into the same nested conditional; build them with one helper.
  auto build_outer_branch = [&](const string& name) {
    XlaBuilder builder(name);
    auto param = Parameter(&builder, 0, tuple_shape, "param");
    auto op = GetTupleElement(param, 1);
    auto pred = GetTupleElement(param, 0);
    Conditional(pred, op, inner_true_computation, op, inner_false_computation);
    return builder.Build().ValueOrDie();
  };
  auto true_computation = build_outer_branch("true");
  auto false_computation = build_outer_branch("false");

  XlaBuilder b(TestName());
  auto constant = ConstantR0<int32>(&b, 0);
  auto pred = Parameter(&b, 0, pred_shape, "param");
  auto param = Tuple(&b, {pred, constant});
  auto cond =
      Conditional(pred, param, true_computation, param, false_computation);
  auto gte = GetTupleElement(cond, 0);
  ASSERT_TRUE(b.first_error().ok()) << b.first_error().error_message();
  // Result is static.
  EXPECT_FALSE(ComputeDynamismLiteral(gte, &b).ValueOrDie().Get<bool>({}));
}
507
508 class UpperBoundInferenceTest : public ValueInferenceTest {
509 public:
UpperBoundInferenceTest(se::Platform * platform=nullptr)510 explicit UpperBoundInferenceTest(se::Platform* platform = nullptr)
511 : platform_(platform) {}
512
ComputeUpperBoundLiteral(XlaOp operand,XlaBuilder * builder,Layout * output_layout=nullptr)513 StatusOr<OptionalLiteral> ComputeUpperBoundLiteral(
514 XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) {
515 ValueInference value_inference(builder);
516 TF_ASSIGN_OR_RETURN(auto literal,
517 value_inference.AnalyzeConstant(
518 operand, ValueInferenceMode::kUpperBound));
519 return literal;
520 }
521
522 se::Platform* platform_;
523 };
524
TEST_F(UpperBoundInferenceTest, GetDimensionSize) {
  // For a [<=2, 3] parameter, the upper bounds of the two dimension sizes are
  // 2 and 3.
  XlaBuilder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeShape(S32, {2, 3}, {true, false});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");

  XlaOp dim0_size = GetDimensionSize(param, 0);
  XlaOp dim1_size = GetDimensionSize(param, 1);
  XlaOp sizes = Tuple(&builder, {dim0_size, dim1_size});

  StatusOr<OptionalLiteral> first = ComputeUpperBoundLiteral(sizes, &builder);
  EXPECT_EQ(first.ValueOrDie().Get<int32>({}, {0}), 2);
  StatusOr<OptionalLiteral> second = ComputeUpperBoundLiteral(sizes, &builder);
  EXPECT_EQ(second.ValueOrDie().Get<int32>({}, {1}), 3);
}
540
TEST_F(UpperBoundInferenceTest, GetDimensionSizeSub) {
  XlaBuilder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeShape(S32, {2, 3}, {true, false});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");

  XlaOp size0 = GetDimensionSize(param, 0);  // Range [0, 2].
  XlaOp size1 = GetDimensionSize(param, 1);  // Range [3, 3].
  // Upper bound of `size1 - size0` is 3 - 0 = 3.
  XlaOp diff = Sub(size1, size0);
  StatusOr<OptionalLiteral> bound = ComputeUpperBoundLiteral(diff, &builder);
  EXPECT_EQ(bound.ValueOrDie().Get<int32>({}), 3);
}
554
TEST_F(UpperBoundInferenceTest, GetDimensionSizeDiv) {
  XlaBuilder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeShape(S32, {2, 3}, {true, false});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");

  XlaOp size0 = GetDimensionSize(param, 0);  // Range [0, 2].
  XlaOp size1 = GetDimensionSize(param, 1);  // Range [3, 3].
  // Upper bound of `size1 / size0` is 3 / 1 = 3: the divisor's lower bound is
  // taken as 1 rather than 0, which would be a divide-by-zero.
  XlaOp quotient = Div(size1, size0);
  StatusOr<OptionalLiteral> bound =
      ComputeUpperBoundLiteral(quotient, &builder);
  EXPECT_EQ(bound.ValueOrDie().Get<int32>({}), 3);
}
568
TEST_F(UpperBoundInferenceTest, SumSubtract) {
  // With x = a and y = b - a, upperbound(x + y) should equal upperbound(b)
  // rather than the sum of the independent bounds of each term.
  XlaBuilder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeShape(S32, {2, 3}, {true, true});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  XlaOp size0 = GetDimensionSize(param, 0);  // Range [0, 2].
  XlaOp size1 = GetDimensionSize(param, 1);  // Range [0, 3].

  XlaOp diff = Sub(size1, size0);
  XlaOp diff_plus_size0 = Add(diff, size0);
  StatusOr<OptionalLiteral> bound =
      ComputeUpperBoundLiteral(diff_plus_size0, &builder);
  EXPECT_EQ(bound.ValueOrDie().Get<int32>({}), 3);

  XlaOp total = Add(size1, size0);
  // upperbound(size1 - size0 + size1 + size0) ==> upperbound(2 * size1)
  XlaOp doubled = Add(diff, total);
  StatusOr<OptionalLiteral> doubled_bound =
      ComputeUpperBoundLiteral(doubled, &builder);
  EXPECT_EQ(doubled_bound.ValueOrDie().Get<int32>({}), 6);
}
587
TEST_F(UpperBoundInferenceTest, SumSubtractWithDataShuffling) {
  // Same as SumSubtract, except that size0 is routed through data-shuffling
  // ops (broadcast, identity convert, slice, reshape) before it is used.
  XlaBuilder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeShape(S32, {2, 3}, {true, true});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  XlaOp raw_size0 = GetDimensionSize(param, 0);  // Range [0, 2].
  XlaOp size1 = GetDimensionSize(param, 1);      // Range [0, 3].

  XlaOp broadcasted = Broadcast(raw_size0, {1, 10});
  XlaOp converted = ConvertElementType(broadcasted, S32);  // Identity convert.
  XlaOp sliced = SliceInDim(converted, /*start_index=*/0, /*limit_index=*/1,
                            /*stride=*/1, /*dimno=*/1);
  XlaOp size0 = Reshape(sliced, {});

  XlaOp diff = Sub(size1, size0);
  XlaOp diff_plus_size0 = Add(diff, size0);
  StatusOr<OptionalLiteral> bound =
      ComputeUpperBoundLiteral(diff_plus_size0, &builder);
  EXPECT_EQ(bound.ValueOrDie().Get<int32>({}), 3);

  XlaOp total = Add(size1, size0);
  // upperbound(size1 - size0 + size1 + size0) ==> upperbound(2 * size1)
  XlaOp doubled = Add(diff, total);
  StatusOr<OptionalLiteral> doubled_bound =
      ComputeUpperBoundLiteral(doubled, &builder);
  EXPECT_EQ(doubled_bound.ValueOrDie().Get<int32>({}), 6);
}
611
TEST_F(UpperBoundInferenceTest, SumSubtractEquivalentGetDimensionSize) {
  XlaBuilder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeShape(S32, {2, 3}, {true, true});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  XlaOp size0 = GetDimensionSize(param, 0);  // Range [0, 2].
  XlaOp size1 = GetDimensionSize(param, 1);  // Range [0, 3].
  // A separately built query of dimension 0, equivalent to size0.
  XlaOp size0_again = GetDimensionSize(param, 0);

  XlaOp diff = Sub(size1, size0_again);
  XlaOp sum = Add(diff, size0);
  // Since size0 == size0_again, upperbound(size0 + size1 - size0_again) is
  // upperbound(size1).
  StatusOr<OptionalLiteral> bound = ComputeUpperBoundLiteral(sum, &builder);
  EXPECT_EQ(bound.ValueOrDie().Get<int32>({}), 3);
}
628
TEST_F(UpperBoundInferenceTest, ParamCantInferBound) {
  // We can infer a parameter's dimension's bound, but not the parameter value's
  // bound, so dividing by a parameter value yields no inferable upper bound.
  XlaBuilder b(TestName());
  auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2}, {true}), "p0");
  auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}, {}), "p1");
  auto gds = GetDimensionSize(p0, 0);
  auto div = Div(gds, p1);
  TF_ASSERT_OK_AND_ASSIGN(OptionalLiteral upper_bound,
                          ComputeUpperBoundLiteral(div, &b));
  EXPECT_FALSE(upper_bound.Get<int32>({}).has_value());
}
642
TEST_F(UpperBoundInferenceTest, KeyValueSort) {
  // Comparator over (key_lhs, key_rhs, value_lhs, value_rhs) that compares
  // only the keys (key_lhs >= key_rhs).
  XlaBuilder comparator_b("comparator");
  auto p0 = Parameter(&comparator_b, 0, ShapeUtil::MakeShape(S32, {}), "p0");
  auto p1 = Parameter(&comparator_b, 1, ShapeUtil::MakeShape(S32, {}), "p1");
  Parameter(&comparator_b, 2, ShapeUtil::MakeShape(S32, {}), "p2");
  Parameter(&comparator_b, 3, ShapeUtil::MakeShape(S32, {}), "p3");
  Compare(p0, p1, ComparisonDirection::kGe);
  TF_ASSERT_OK_AND_ASSIGN(auto comparator, comparator_b.Build());

  int64_t elem_count = 17;
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {elem_count}), "p0");
  auto iota = Iota(&b, S32, elem_count);
  auto sort = Sort({param, iota}, comparator);
  // The iota values, permuted according to the (dynamic) keys.
  auto gte = GetTupleElement(sort, 1);

  // The analysis does not depend on the element index, so run it only once.
  TF_ASSERT_OK_AND_ASSIGN(OptionalLiteral upper_bounds,
                          ComputeUpperBoundLiteral(gte, &b));
  for (int64_t i = 0; i < elem_count; ++i) {
    auto elem_bound = upper_bounds.Get<int32>({i});
    // We can infer the bound of sort. ASSERT (not EXPECT) because value() on
    // an empty optional must not be reached.
    ASSERT_TRUE(elem_bound.has_value()) << "no upper bound for element " << i;
    // The bound of the sort result is the max value in the input.
    EXPECT_EQ(elem_bound.value(), elem_count - 1);
  }
}
668
669 class ConstValueInferenceTest : public ValueInferenceTest {
670 public:
ConstValueInferenceTest(se::Platform * platform=nullptr)671 explicit ConstValueInferenceTest(se::Platform* platform = nullptr)
672 : platform_(platform) {}
673
ComputeConstantValueLiteral(XlaOp operand,XlaBuilder * builder,Layout * output_layout=nullptr)674 StatusOr<OptionalLiteral> ComputeConstantValueLiteral(
675 XlaOp operand, XlaBuilder* builder, Layout* output_layout = nullptr) {
676 ValueInference value_inference(builder);
677 TF_ASSIGN_OR_RETURN(auto literal, value_inference.AnalyzeConstant(
678 operand, ValueInferenceMode::kValue));
679 return literal;
680 }
681
682 se::Platform* platform_;
683 };
684
TEST_F(ConstValueInferenceTest, ConstValuePassThroughSetBound) {
  XlaBuilder b(TestName());
  auto p0 = ConstantR0<int32>(&b, 32);
  Shape shape = ShapeUtil::MakeShape(S32, {});
  // The literal attached to the SetBound custom call is built as a
  // (bound, dynamism) tuple: bound 32, marked non-dynamic.
  xla::Literal dynamism = xla::LiteralUtil::CreateR0<bool>(false);
  xla::Literal bound = xla::LiteralUtil::CreateR0<int32>(32);
  xla::Literal tuple =
      xla::LiteralUtil::MakeTupleOwned(std::move(bound), std::move(dynamism));
  auto set_bound =
      CustomCall(&b, "SetBound", {p0}, shape, "", false, {}, &tuple);
  TF_ASSERT_OK_AND_ASSIGN(OptionalLiteral value,
                          ComputeConstantValueLiteral(set_bound, &b));
  auto result = value.Get<int32>({});
  // The constant operand's value should pass straight through SetBound.
  // ASSERT (not EXPECT) so value() is never called on an empty optional.
  ASSERT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), 32);
}
700
701 } // namespace
702 } // namespace xla
703