1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
17
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
28
29 namespace xla {
30 namespace {
31
32 using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&);
33
CreateScalarComputation(const string & name,PrimitiveType type,XlaBuilder * builder,XlaOpGenerator generator)34 XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
35 XlaBuilder* builder,
36 XlaOpGenerator generator) {
37 std::unique_ptr<XlaBuilder> b;
38 if (type == PRED) {
39 b = builder->CreateSubBuilder(name);
40 } else {
41 b = builder->CreateSubBuilder(
42 absl::StrCat(name, "_", PrimitiveType_Name(type)));
43 }
44
45 const Shape scalar = ShapeUtil::MakeShape(type, {});
46 auto lhs = Parameter(b.get(), 0, scalar, "lhs");
47 auto rhs = Parameter(b.get(), 1, scalar, "rhs");
48 generator(b.get(), lhs, rhs);
49 return b->BuildAndNoteError();
50 }
51
52 } // namespace
53
CreateScalarAddComputation(PrimitiveType type,XlaBuilder * builder)54 XlaComputation CreateScalarAddComputation(PrimitiveType type,
55 XlaBuilder* builder) {
56 return CreateScalarComputation(
57 "add", type, builder,
58 [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
59 return Add(lhs, rhs);
60 });
61 }
62
CreateScalarMultiplyComputation(PrimitiveType type,XlaBuilder * builder)63 XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
64 XlaBuilder* builder) {
65 return CreateScalarComputation(
66 "mul", type, builder,
67 [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
68 return Mul(lhs, rhs);
69 });
70 }
71
CreateScalarGeComputation(PrimitiveType type,XlaBuilder * builder)72 XlaComputation CreateScalarGeComputation(PrimitiveType type,
73 XlaBuilder* builder) {
74 return CreateScalarComputation("ge", type, builder,
75 [](XlaBuilder* b, const XlaOp& lhs,
76 const XlaOp& rhs) { return Ge(lhs, rhs); });
77 }
78
CreateScalarMaxComputation(PrimitiveType type,XlaBuilder * builder)79 XlaComputation CreateScalarMaxComputation(PrimitiveType type,
80 XlaBuilder* builder) {
81 return CreateScalarComputation(
82 "max", type, builder,
83 [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
84 return Max(lhs, rhs);
85 });
86 }
87
CreateScalarMinComputation(PrimitiveType type,XlaBuilder * builder)88 XlaComputation CreateScalarMinComputation(PrimitiveType type,
89 XlaBuilder* builder) {
90 return CreateScalarComputation(
91 "min", type, builder,
92 [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
93 return Min(lhs, rhs);
94 });
95 }
96
CreateScalarAndComputation(PrimitiveType type,XlaBuilder * builder)97 XlaComputation CreateScalarAndComputation(PrimitiveType type,
98 XlaBuilder* builder) {
99 return CreateScalarComputation(
100 "and", type, builder,
101 [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
102 return And(lhs, rhs);
103 });
104 }
105
CreateScalarOrComputation(PrimitiveType type,XlaBuilder * builder)106 XlaComputation CreateScalarOrComputation(PrimitiveType type,
107 XlaBuilder* builder) {
108 return CreateScalarComputation("or", type, builder,
109 [](XlaBuilder* b, const XlaOp& lhs,
110 const XlaOp& rhs) { return Or(lhs, rhs); });
111 }
112
Any(XlaOp predicates)113 XlaOp Any(XlaOp predicates) {
114 XlaBuilder* builder = predicates.builder();
115 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
116 auto f = ConstantR0<bool>(builder, false);
117 XlaComputation logical_or = CreateScalarOrComputation(PRED, builder);
118 TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
119 builder->GetShape(predicates));
120 std::vector<int64> all_dimensions(predicates_shape.rank());
121 std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
122 return Reduce(predicates, f, logical_or, all_dimensions);
123 });
124 }
125
126 namespace {
127
ArgMinMax(XlaOp input,PrimitiveType output_type,int axis,bool is_min)128 XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min) {
129 XlaBuilder* builder = input.builder();
130 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
131 TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
132 XlaOp init_value;
133 XlaComputation reducer;
134 if (is_min) {
135 init_value = MaxValue(builder, input_shape.element_type());
136 reducer = CreateScalarMinComputation(input_shape.element_type(), builder);
137 } else {
138 init_value = MinValue(builder, input_shape.element_type());
139 reducer = CreateScalarMaxComputation(input_shape.element_type(), builder);
140 }
141
142 XlaOp input_max = Reduce(input, init_value, reducer,
143 /*dimensions_to_reduce=*/{axis});
144 std::vector<int64> broadcast_dims(input_shape.rank() - 1);
145 std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
146 std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
147 // Compute a mask that has 1s for elements equal to the maximum.
148 XlaOp partial_mask =
149 ConvertElementType(Eq(input, input_max, broadcast_dims), output_type);
150
151 // In order to make identity elements for a bitwise And, we:
152 // Left shift the 1 to the leftmost bit, yielding 0x10...0
153 // Arithmetic right shift the 1 back to the rightmost bit, yielding
154 // 0xFF...F
155 int32 bits_in_type =
156 ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1;
157 XlaOp shift_amount = ConstantR0WithType(builder, output_type, bits_in_type);
158 XlaOp full_mask = ShiftRightArithmetic(
159 ShiftLeft(partial_mask, shift_amount), shift_amount);
160
161 // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its
162 // index.
163
164 const int64 axis_size = ShapeUtil::GetDimension(input_shape, axis);
165 XlaOp iota = Iota(builder, output_type, axis_size);
166 XlaOp product = And(full_mask, iota, /*broadcast_dimensions=*/{axis});
167
168 // If there are multiple maximum elements, choose the one with the highest
169 // index.
170 return Reduce(product, MinValue(builder, output_type),
171 CreateScalarMaxComputation(output_type, builder),
172 /*dimensions_to_reduce=*/{axis});
173 });
174 }
175
176 } // namespace
177
ArgMax(XlaOp input,PrimitiveType output_type,int axis)178 XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis) {
179 return ArgMinMax(input, output_type, axis, /*is_min=*/false);
180 }
181
ArgMin(XlaOp input,PrimitiveType output_type,int axis)182 XlaOp ArgMin(XlaOp input, PrimitiveType output_type, int axis) {
183 return ArgMinMax(input, output_type, axis, /*is_min=*/true);
184 }
185
186 } // namespace xla
187