• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/comparators.h"
17 
18 #include <limits>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/strings/str_cat.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/primitive_util.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 
32 namespace xla {
33 namespace {
34 
35 using XlaCompareOp = XlaOp (*)(XlaOp, XlaOp, absl::Span<const int64_t>);
36 
CreateScalarComparisonComputation(const std::string & name,const std::vector<PrimitiveType> & operand_types,XlaBuilder * builder,XlaCompareOp generator)37 XlaComputation CreateScalarComparisonComputation(
38     const std::string& name, const std::vector<PrimitiveType>& operand_types,
39     XlaBuilder* builder, XlaCompareOp generator) {
40   CHECK_NE(operand_types.size(), 0);
41   std::vector<std::optional<XlaCompareOp>> generators(operand_types.size());
42   generators[0] = generator;
43   return CreateScalarComparisonComputation(name, operand_types, generators,
44                                            builder);
45 }
46 }  // namespace
47 
CreateScalarComparisonComputation(const std::string & name,const std::vector<PrimitiveType> & operand_types,const std::vector<std::optional<XlaCompareOp>> & generators,XlaBuilder * builder)48 XlaComputation CreateScalarComparisonComputation(
49     const std::string& name, const std::vector<PrimitiveType>& operand_types,
50     const std::vector<std::optional<XlaCompareOp>>& generators,
51     XlaBuilder* builder) {
52   // Create a default computation where we compare only the first two
53   // parameters of type 'operand_types[0]'.
54   auto b = builder->CreateSubBuilder(name);
55   if (operand_types.empty()) {
56     b->ReportError(InvalidArgument("operand_types should not be empty"));
57     return b->BuildAndNoteError();
58   }
59 
60   CHECK_EQ(operand_types.size(), generators.size());
61   int64_t parameter_count = 0;
62   int64_t last_generator_index = 0;
63   std::vector<XlaOp> lhs_params;
64   std::vector<XlaOp> rhs_params;
65 
66   // For each type in 'operand_types' we create two parameters of this type. The
67   // idea is that this computation can be used by n-ary Sort, and potentially
68   // should support comparing also the other operands of sort. In this default
69   // computation, however, we will not actually use any parameters except the
70   // first two.
71   for (auto operand_type : operand_types) {
72     auto scalar_shape = ShapeUtil::MakeShape(operand_type, {});
73     auto lhs_param = Parameter(b.get(), parameter_count * 2, scalar_shape,
74                                absl::StrCat("p.", parameter_count, ".lhs"));
75     auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape,
76                                absl::StrCat("p.", parameter_count, ".rhs"));
77     lhs_params.emplace_back(lhs_param);
78     rhs_params.emplace_back(rhs_param);
79     if (generators[parameter_count].has_value()) {
80       last_generator_index = parameter_count;
81     }
82     parameter_count++;
83   }
84 
85   CHECK_NE(parameter_count, 0);
86 
87   auto shape_or = b->GetShape(lhs_params[0]);
88   if (!shape_or.ok()) {
89     b->ReportError(shape_or.status());
90     return {};
91   }
92   Shape shape = shape_or.ValueOrDie();
93   shape.set_element_type(PRED);
94   XlaOp param_equal =
95       Broadcast(One(b.get(), shape.element_type()), shape.dimensions());
96   XlaOp result = param_equal;
97 
98   for (int64_t i = 0; i < parameter_count; i++) {
99     if (generators[i].has_value()) {
100       result = Select(param_equal,
101                       generators[i].value()(lhs_params[i], rhs_params[i], {}),
102                       result);
103       if (i != last_generator_index) {
104         param_equal =
105             And(param_equal, EqTotalOrder(lhs_params[i], rhs_params[i]));
106       }
107     }
108   }
109 
110   return b->BuildAndNoteError();
111 }
112 
113 // Creates a scalar less-than computation and returns it.
CreateScalarLtComputation(const std::vector<PrimitiveType> & operand_types,XlaBuilder * builder)114 XlaComputation CreateScalarLtComputation(
115     const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
116   return CreateScalarComparisonComputation("compare-less-than", operand_types,
117                                            builder, LtTotalOrder);
118 }
119 
120 // Creates a scalar greater-than computation and returns it.
CreateScalarGtComputation(const std::vector<PrimitiveType> & operand_types,XlaBuilder * builder)121 XlaComputation CreateScalarGtComputation(
122     const std::vector<PrimitiveType>& operand_types, XlaBuilder* builder) {
123   return CreateScalarComparisonComputation(
124       "compare-greater-than", operand_types, builder, GtTotalOrder);
125 }
126 
127 }  // namespace xla
128