1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"

#include <cstdint>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
28 
29 namespace xla {
30 
CreateScalarComputation(const string & name,PrimitiveType type,XlaBuilder * builder,XlaOpGenerator generator)31 XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
32                                        XlaBuilder* builder,
33                                        XlaOpGenerator generator) {
34   std::unique_ptr<XlaBuilder> b;
35   if (type == PRED) {
36     b = builder->CreateSubBuilder(name);
37   } else {
38     b = builder->CreateSubBuilder(
39         absl::StrCat(name, "_", PrimitiveType_Name(type)));
40   }
41 
42   const Shape scalar = ShapeUtil::MakeShape(type, {});
43   auto lhs = Parameter(b.get(), 0, scalar, "lhs");
44   auto rhs = Parameter(b.get(), 1, scalar, "rhs");
45   generator(lhs, rhs);
46   return b->BuildAndNoteError();
47 }
48 
CreateScalarAddComputation(PrimitiveType type,XlaBuilder * builder)49 XlaComputation CreateScalarAddComputation(PrimitiveType type,
50                                           XlaBuilder* builder) {
51   return CreateScalarComputation(
52       "add", type, builder, [](XlaOp lhs, XlaOp rhs) { return Add(lhs, rhs); });
53 }
54 
CreateScalarMultiplyComputation(PrimitiveType type,XlaBuilder * builder)55 XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
56                                                XlaBuilder* builder) {
57   return CreateScalarComputation(
58       "mul", type, builder, [](XlaOp lhs, XlaOp rhs) { return Mul(lhs, rhs); });
59 }
60 
CreateScalarGeComputation(PrimitiveType type,XlaBuilder * builder)61 XlaComputation CreateScalarGeComputation(PrimitiveType type,
62                                          XlaBuilder* builder) {
63   return CreateScalarComputation(
64       "ge", type, builder, [](XlaOp lhs, XlaOp rhs) { return Ge(lhs, rhs); });
65 }
66 
CreateScalarMaxComputation(PrimitiveType type,XlaBuilder * builder)67 XlaComputation CreateScalarMaxComputation(PrimitiveType type,
68                                           XlaBuilder* builder) {
69   return CreateScalarComputation(
70       "max", type, builder, [](XlaOp lhs, XlaOp rhs) { return Max(lhs, rhs); });
71 }
72 
CreateScalarMinComputation(PrimitiveType type,XlaBuilder * builder)73 XlaComputation CreateScalarMinComputation(PrimitiveType type,
74                                           XlaBuilder* builder) {
75   return CreateScalarComputation(
76       "min", type, builder, [](XlaOp lhs, XlaOp rhs) { return Min(lhs, rhs); });
77 }
78 
CreateScalarAndComputation(PrimitiveType type,XlaBuilder * builder)79 XlaComputation CreateScalarAndComputation(PrimitiveType type,
80                                           XlaBuilder* builder) {
81   return CreateScalarComputation(
82       "and", type, builder, [](XlaOp lhs, XlaOp rhs) { return And(lhs, rhs); });
83 }
84 
CreateScalarOrComputation(PrimitiveType type,XlaBuilder * builder)85 XlaComputation CreateScalarOrComputation(PrimitiveType type,
86                                          XlaBuilder* builder) {
87   return CreateScalarComputation(
88       "or", type, builder, [](XlaOp lhs, XlaOp rhs) { return Or(lhs, rhs); });
89 }
90 
CreateScalarIdentityWithZeroComputation(PrimitiveType type,XlaBuilder * builder)91 XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type,
92                                                        XlaBuilder* builder) {
93   XlaComputation reducer =
94       (primitive_util::IsIntegralType(type) || type == PRED)
95           ? CreateScalarOrComputation(type, builder)
96           : CreateScalarAddComputation(type, builder);
97   return reducer;
98 }
99 
Any(XlaOp predicates)100 XlaOp Any(XlaOp predicates) {
101   XlaBuilder* builder = predicates.builder();
102   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
103     auto f = ConstantR0<bool>(builder, false);
104     XlaComputation logical_or = CreateScalarOrComputation(PRED, builder);
105     TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
106                         builder->GetShape(predicates));
107     std::vector<int64> all_dimensions(predicates_shape.rank());
108     std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
109     return Reduce(predicates, f, logical_or, all_dimensions);
110   });
111 }
112 
113 namespace {
114 
CreateMinMaxComputation(XlaBuilder * outer_builder,PrimitiveType value_type,PrimitiveType index_type,bool is_min,bool stable,bool tie_low)115 XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder,
116                                        PrimitiveType value_type,
117                                        PrimitiveType index_type, bool is_min,
118                                        bool stable, bool tie_low) {
119   auto sub_builder = outer_builder->CreateSubBuilder("minmax_func");
120   XlaBuilder* b = sub_builder.get();
121   XlaOp lhs_value =
122       Parameter(b, 0, ShapeUtil::MakeShape(value_type, {}), "lhs_value");
123   XlaOp lhs_index =
124       Parameter(b, 1, ShapeUtil::MakeShape(index_type, {}), "lhs_index");
125   XlaOp rhs_value =
126       Parameter(b, 2, ShapeUtil::MakeShape(value_type, {}), "rhs_value");
127   XlaOp rhs_index =
128       Parameter(b, 3, ShapeUtil::MakeShape(index_type, {}), "rhs_index");
129 
130   XlaOp cmp = is_min ? Le(lhs_value, rhs_value) : Ge(lhs_value, rhs_value);
131   XlaOp max = Select(cmp, lhs_value, rhs_value);
132   XlaOp arg_max = Select(cmp, lhs_index, rhs_index);
133   if (stable) {
134     XlaOp eq = Eq(lhs_value, rhs_value);
135     XlaOp tie_id =
136         tie_low ? Min(lhs_index, rhs_index) : Max(lhs_index, rhs_index);
137     arg_max = Select(eq, tie_id, arg_max);
138   }
139   Tuple(b, {max, arg_max});
140   return b->BuildAndNoteError();
141 }
142 
ArgMinMax(XlaOp input,PrimitiveType output_type,int axis,bool is_min,bool stable,bool tie_low)143 XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min,
144                 bool stable, bool tie_low) {
145   XlaBuilder* builder = input.builder();
146   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
147     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
148     XlaOp value_init_value;
149     if (is_min) {
150       value_init_value = MaxValue(builder, input_shape.element_type());
151     } else {
152       value_init_value = MinValue(builder, input_shape.element_type());
153     }
154     int64_t dimension_size = input_shape.dimensions(axis);
155     auto index_type = dimension_size <= INT32_MAX ? S32 : output_type;
156     XlaOp index_init_value = Zero(builder, index_type);
157     auto iota_shape = input_shape;
158     iota_shape.set_element_type(index_type);
159     XlaOp iota = Iota(builder, iota_shape, axis);
160 
161     XlaComputation reducer =
162         CreateMinMaxComputation(builder, input_shape.element_type(), index_type,
163                                 is_min, stable, tie_low);
164     XlaOp max_argmax = Reduce(builder, {input, iota},
165                               {value_init_value, index_init_value}, reducer,
166                               /*dimensions_to_reduce=*/{axis});
167     XlaOp argmax = GetTupleElement(max_argmax, 1);
168     if (index_type != output_type) {
169       argmax = ConvertElementType(argmax, output_type);
170     }
171     return argmax;
172   });
173 }
174 
ArgMinMaxTwoPass(XlaOp input,PrimitiveType output_type,int axis,bool is_min,bool tie_low)175 XlaOp ArgMinMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
176                        bool is_min, bool tie_low) {
177   XlaBuilder* builder = input.builder();
178   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
179     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
180     XlaOp init_value;
181     XlaComputation reducer;
182     if (is_min) {
183       init_value = MaxValue(builder, input_shape.element_type());
184       reducer = CreateScalarMinComputation(input_shape.element_type(), builder);
185     } else {
186       init_value = MinValue(builder, input_shape.element_type());
187       reducer = CreateScalarMaxComputation(input_shape.element_type(), builder);
188     }
189 
190     XlaOp iota = Iota(
191         builder, ShapeUtil::ChangeElementType(input_shape, output_type), axis);
192     XlaOp reduced_input = Reduce(input, init_value, reducer,
193                                  /*dimensions_to_reduce=*/{axis});
194     std::vector<int64> broadcast_dims(input_shape.rank() - 1);
195     std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
196     std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
197     if (tie_low) {
198       XlaOp max_idx = MaxValue(builder, output_type);
199       XlaOp select_mask = Select(Eq(input, reduced_input, broadcast_dims),
200                                  /*on_true=*/iota,
201                                  /*on_false=*/
202                                  max_idx);
203       return Reduce(select_mask, max_idx,
204                     CreateScalarMinComputation(output_type, builder),
205                     /*dimensions_to_reduce=*/{axis});
206     } else {
207       XlaOp min_idx = MinValue(builder, output_type);
208       XlaOp select_mask = Select(Eq(input, reduced_input, broadcast_dims),
209                                  /*on_true=*/iota,
210                                  /*on_false=*/
211                                  min_idx);
212       return Reduce(select_mask, min_idx,
213                     CreateScalarMaxComputation(output_type, builder),
214                     /*dimensions_to_reduce=*/{axis});
215     }
216   });
217 }
218 }  // namespace
219 
ArgMax(XlaOp input,PrimitiveType output_type,int axis,bool stable,bool tie_low)220 XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis, bool stable,
221              bool tie_low) {
222   return ArgMinMax(input, output_type, axis, /*is_min=*/false, stable, tie_low);
223 }
224 
ArgMin(XlaOp input,PrimitiveType output_type,int axis,bool stable,bool tie_low)225 XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis, bool stable,
226              bool tie_low) {
227   return ArgMinMax(input, output_type, axis, /*is_min=*/true, stable, tie_low);
228 }
229 
ArgMaxTwoPass(XlaOp input,PrimitiveType output_type,int axis,bool tie_low)230 XlaOp ArgMaxTwoPass(XlaOp input, PrimitiveType output_type, int axis,
231                     bool tie_low) {
232   return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/false, tie_low);
233 }
234 
ArgMinTwoPass(XlaOp input,PrimitiveType output_type,int axis,bool tie_low)235 XlaOp ArgMinTwoPass(XlaOp input, PrimitiveType output_type, int axis,
236                     bool tie_low) {
237   return ArgMinMaxTwoPass(input, output_type, axis, /*is_min=*/true, tie_low);
238 }
239 }  // namespace xla
240