1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/comparators.h"
17
18 #include <limits>
19 #include <string>
20 #include <vector>
21
22 #include "absl/strings/str_cat.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/primitive_util.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31
32 namespace xla {
33 namespace {
34
35 using XlaCompareOp = XlaOp (*)(XlaOp, XlaOp, absl::Span<const int64_t>);
36
CreateScalarComparisonComputation(const std::string & name,const std::vector<PrimitiveType> & operand_types,XlaBuilder * builder,XlaCompareOp generator)37 XlaComputation CreateScalarComparisonComputation(
38 const std::string& name, const std::vector<PrimitiveType>& operand_types,
39 XlaBuilder* builder, XlaCompareOp generator) {
40 CHECK_NE(operand_types.size(), 0);
41 std::vector<std::optional<XlaCompareOp>> generators(operand_types.size());
42 generators[0] = generator;
43 return CreateScalarComparisonComputation(name, operand_types, generators,
44 builder);
45 }
46 } // namespace
47
CreateScalarComparisonComputation(const std::string & name,const std::vector<PrimitiveType> & operand_types,const std::vector<std::optional<XlaCompareOp>> & generators,XlaBuilder * builder)48 XlaComputation CreateScalarComparisonComputation(
49 const std::string& name, const std::vector<PrimitiveType>& operand_types,
50 const std::vector<std::optional<XlaCompareOp>>& generators,
51 XlaBuilder* builder) {
52 // Create a default computation where we compare only the first two
53 // parameters of type 'operand_types[0]'.
54 auto b = builder->CreateSubBuilder(name);
55 if (operand_types.empty()) {
56 b->ReportError(InvalidArgument("operand_types should not be empty"));
57 return b->BuildAndNoteError();
58 }
59
60 CHECK_EQ(operand_types.size(), generators.size());
61 int64_t parameter_count = 0;
62 int64_t last_generator_index = 0;
63 std::vector<XlaOp> lhs_params;
64 std::vector<XlaOp> rhs_params;
65
66 // For each type in 'operand_types' we create two parameters of this type. The
67 // idea is that this computation can be used by n-ary Sort, and potentially
68 // should support comparing also the other operands of sort. In this default
69 // computation, however, we will not actually use any parameters except the
70 // first two.
71 for (auto operand_type : operand_types) {
72 auto scalar_shape = ShapeUtil::MakeShape(operand_type, {});
73 auto lhs_param = Parameter(b.get(), parameter_count * 2, scalar_shape,
74 absl::StrCat("p.", parameter_count, ".lhs"));
75 auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape,
76 absl::StrCat("p.", parameter_count, ".rhs"));
77 lhs_params.emplace_back(lhs_param);
78 rhs_params.emplace_back(rhs_param);
79 if (generators[parameter_count].has_value()) {
80 last_generator_index = parameter_count;
81 }
82 parameter_count++;
83 }
84
85 CHECK_NE(parameter_count, 0);
86
87 auto shape_or = b->GetShape(lhs_params[0]);
88 if (!shape_or.ok()) {
89 b->ReportError(shape_or.status());
90 return {};
91 }
92 Shape shape = shape_or.ValueOrDie();
93 shape.set_element_type(PRED);
94 XlaOp param_equal =
95 Broadcast(One(b.get(), shape.element_type()), shape.dimensions());
96 XlaOp result = param_equal;
97
98 for (int64_t i = 0; i < parameter_count; i++) {
99 if (generators[i].has_value()) {
100 result = Select(param_equal,
101 generators[i].value()(lhs_params[i], rhs_params[i], {}),
102 result);
103 if (i != last_generator_index) {
104 param_equal =
105 And(param_equal, EqTotalOrder(lhs_params[i], rhs_params[i]));
106 }
107 }
108 }
109
110 return b->BuildAndNoteError();
111 }
112
113 // Creates a scalar less-than computation and returns it.
CreateScalarLtComputation(const std::vector<PrimitiveType> & operand_types,XlaBuilder * builder)114 XlaComputation CreateScalarLtComputation(
115 const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
116 return CreateScalarComparisonComputation("compare-less-than", operand_types,
117 builder, LtTotalOrder);
118 }
119
120 // Creates a scalar greater-than computation and returns it.
CreateScalarGtComputation(const std::vector<PrimitiveType> & operand_types,XlaBuilder * builder)121 XlaComputation CreateScalarGtComputation(
122 const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
123 return CreateScalarComparisonComputation(
124 "compare-greater-than", operand_types, builder, GtTotalOrder);
125 }
126
127 } // namespace xla
128