/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <array>
#include <memory>
#include <vector>

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace {

using xla::S32;
using xla::XlaOp;

41 // "Shifts" a rank-1 array one element to the right, inserting a 0 at the
42 // beginning and cutting off the last element of the array.
43 //
44 // That is, transforms [x0, x1, ..., xn] into [0, x0, ..., xn-1].
ShiftElemsRight(XlaOp x)45 StatusOr<XlaOp> ShiftElemsRight(XlaOp x) {
46 xla::XlaBuilder* b = x.builder();
47 StatusOr<xla::Shape> shape = b->GetShape(x);
48 TF_RETURN_IF_ERROR(shape.status());
49 TF_RET_CHECK(shape->dimensions_size() == 1);
50 int64_t n = shape->dimensions(0);
51
52 XlaOp padded = xla::PadInDim(x, xla::Zero(b, shape->element_type()),
53 /*dimno=*/0, /*pad_lo=*/1, /*pad_hi=*/0);
54 return xla::SliceInDim(padded, /*start_index=*/0, /*limit_index=*/n,
55 /*stride=*/1, /*dimno=*/0);
56 }
57
// Recursive prefix-sum algorithm.
//
// - Let the input be an array x.
// - Let evens be [x0, x2, ...].
// - Let odds be [x1, x3, ...].
// - Let combined be evens + odds.
// - Let psum = prefix-sum(combined), recursively.
//
// Then the prefix-sum of x is the interleaving of psum - odds and psum.
// Written out, this is:
//
//   [psum[0] - odds[0], psum[0], psum[1] - odds[1], psum[1], ...].
//
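// For example, for x = [1, 2, 3, 4]: evens = [1, 3], odds = [2, 4],
// combined = [3, 7], psum = [3, 10], and the interleaving
// [3 - 2, 3, 10 - 4, 10] = [1, 3, 6, 10] is the inclusive prefix-sum of x.
//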
// Requires: `arr` is a 1D S32 array whose length is padded to a power of 2.
//
// Optimization: Rather than split the input into two slices (evens/odds), we
// split it into kNumSlices pieces. The basic algorithm is the same, but this
// reduces the number of GPU kernels we have to launch.
//
// There are much more efficient algorithms to be had! In particular, on GPU
// this launches O(log4 n) kernels, but there are efficient algorithms that use
// just one kernel, see
// https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back
//
// Nonetheless, this is much simpler than the algorithm in the paper above,
// while still being much faster than implementing tf.where by sorting the
// input.
StatusOr<XlaOp> PrefixSum(XlaOp arr) {
  xla::XlaBuilder* b = arr.builder();
  StatusOr<xla::Shape> input_shape = b->GetShape(arr);
  TF_RETURN_IF_ERROR(input_shape.status());

  TF_RET_CHECK(input_shape->dimensions_size() == 1);
  int64_t n = input_shape->dimensions(0);

  // The original input length must be a power of 2, but we recursively divide
  // it into kNumSlices chunks. Assuming kNumSlices == 4, this means our
  // base-case needs to handle n == 1 (original length was a power of 4) or
  // n == 2 (original length was a power of 2 but not a power of 4).
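  // For example, n = 8 recurses as 8 -> 2 (the n == 2 base case), whereas
  // n = 16 recurses as 16 -> 4 -> 1 (the n == 1 base case).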
  constexpr int kNumSlices = 4;
  if (n <= 1) {
    return arr;
  }
  if (n == 2) {
    TF_ASSIGN_OR_RETURN(XlaOp shifted, ShiftElemsRight(arr));
    return arr + shifted;
  }
  TF_RET_CHECK(n % kNumSlices == 0);

  std::array<XlaOp, kNumSlices> slices;
  for (int i = 0; i < slices.size(); i++) {
    slices[i] = xla::Slice(arr, /*start_indices=*/{i}, /*limit_indices=*/{n},
                           /*strides=*/{kNumSlices});
  }

  XlaOp combined = slices[0];
  for (int i = 1; i < kNumSlices; ++i) {
    combined = combined + slices[i];
  }

  TF_ASSIGN_OR_RETURN(XlaOp psum, PrefixSum(combined));

  std::array<XlaOp, kNumSlices> slices_psummed;
  slices_psummed[kNumSlices - 1] = psum;
  for (int i = kNumSlices - 2; i >= 0; --i) {
    slices_psummed[i] = slices_psummed[i + 1] - slices[i + 1];
  }

  // Interleave the slices.
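  // Each slice is expanded with zeros so that its elements land at their
  // final interleaved positions: with interior padding of kNumSlices - 1,
  // edge padding i on the left and kNumSlices - i - 1 on the right, slice 0's
  // [a0, a1] becomes [a0, 0, 0, 0, a1, 0, 0, 0] and slice 1's [b0, b1] becomes
  // [0, b0, 0, 0, 0, b1, 0, 0] (for kNumSlices == 4). Summing the padded
  // slices then produces the interleaved result.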
  std::array<XlaOp, kNumSlices> slices_padded;
  for (int i = 0; i < kNumSlices; ++i) {
    xla::PaddingConfig padding_config;
    auto* dim = padding_config.add_dimensions();
    dim->set_edge_padding_low(i);
    dim->set_edge_padding_high(kNumSlices - i - 1);
    dim->set_interior_padding(kNumSlices - 1);
    slices_padded[i] =
        xla::Pad(slices_psummed[i], xla::Zero(b, S32), padding_config);
  }

  XlaOp ret = slices_padded[0];
  for (int i = 1; i < kNumSlices; ++i) {
    ret = ret + slices_padded[i];
  }

  return ret;
}

// The prefix-sum implementation works better on CPU/GPU, whereas the
// sort-based one works better on TPU.
bool ShouldUsePrefixSumImpl(const DeviceType& dt) {
  absl::string_view t = dt.type_string();
  return t == DEVICE_CPU_XLA_JIT || t == DEVICE_GPU_XLA_JIT ||
         t == DEVICE_XLA_CPU || t == DEVICE_XLA_GPU;
}

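// Sort-based implementation of tf.where: stably sorts the flattened predicate
// (descending, via a Gt comparator) together with one iota per dimension, so
// the indices of the true elements end up at the front in their original
// order; the trailing rows are then marked as padding by setting a dynamic
// dimension size on the result.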
StatusOr<XlaOp> CompileWhereWithSort(XlaOpKernelContext* ctx) {
  XlaOp condition = ctx->Input(0);
  TF_ASSIGN_OR_RETURN(xla::Shape input_shape,
                      ctx->builder()->GetShape(condition));
  auto iota_shape = input_shape;
  iota_shape.set_element_type(xla::S32);

  int64_t flattened_size = xla::Product(iota_shape.dimensions());
  XlaOp reshaped_condition = xla::Reshape(condition, {flattened_size});
  XlaOp zeros = xla::ZerosLike(reshaped_condition);
  XlaOp compared = xla::Ne(reshaped_condition, zeros);

  std::vector<XlaOp> to_sort = {compared};
  std::vector<xla::PrimitiveType> types_to_sort = {xla::PRED};
  // Generate an iota for each dimension; concatenated after the sort, these
  // become the multidimensional index of each element.
  for (int64_t axis = 0; axis < iota_shape.rank(); ++axis) {
    XlaOp iota = xla::Iota(ctx->builder(), iota_shape, axis);
    XlaOp reshaped = xla::Reshape(iota, {flattened_size});
    to_sort.push_back(reshaped);
    types_to_sort.push_back(xla::S32);
  }

  XlaOp sorted = xla::Sort(
      to_sort, xla::CreateScalarGtComputation(types_to_sort, ctx->builder()),
      /*dimension=*/0, /*is_stable=*/true);
  std::vector<XlaOp> to_concat;
  for (int64_t i = 0; i < iota_shape.rank(); ++i) {
    XlaOp index_single_dim = xla::GetTupleElement(sorted, i + 1);
    to_concat.push_back(xla::Reshape(index_single_dim, {flattened_size, 1}));
  }

  XlaOp result = xla::ConcatInDim(ctx->builder(), to_concat, 1);
  result = xla::ConvertElementType(result, ctx->output_xla_type(0));

  // Dynamic padder will handle the dynamic dimension.
  XlaOp compared_int = xla::ConvertElementType(compared, xla::S32);
  XlaOp length =
      xla::ReduceAll(compared_int, xla::Zero(ctx->builder(), xla::S32),
                     xla::CreateScalarAddComputation(xla::S32, ctx->builder()));
  StatusOr<XlaOp> rebounded_result = xla::SetDimensionSizeWithRebound(
      &ctx->value_inference(), result, length, 0);
  if (rebounded_result.ok()) {
    return rebounded_result;
  }
  // TODO(b/207187072): Remove special handling once dynamic reshape can also
  // be handled.
  return xla::SetDimensionSize(result, length, 0);
}

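// Prefix-sum-based implementation of tf.where: a prefix sum over the
// flattened predicate assigns each true element its output row, the
// per-dimension iotas are scattered into those rows, and the tail of the
// result is marked as padding by setting a dynamic dimension size.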
StatusOr<XlaOp> CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) {
  xla::XlaBuilder* b = ctx->builder();
  XlaOp condition = ctx->Input(0);

  TF_ASSIGN_OR_RETURN(xla::Shape input_shape, b->GetShape(condition));

  int64_t flattened_size = xla::Product(input_shape.dimensions());
  XlaOp reshaped_condition = xla::Reshape(condition, {flattened_size});
  XlaOp zeros = xla::ZerosLike(reshaped_condition);
  XlaOp preds =
      xla::ConvertElementType(xla::Ne(reshaped_condition, zeros), S32);

  // Given preds, we compute prefix_sum and out_idxs as in the following
  // example.
  //
  //   preds = [T, F, F, T, F, T], therefore
  //   prefix_sum = [1, 1, 1, 2, 2, 3], and
  //   out_idxs = [0, ⊥, ⊥, 1, ⊥, 2], where ⊥ is an OOB index.
  //
  // We then scatter out_idxs into the result.
  TF_ASSIGN_OR_RETURN(
      XlaOp padded_prefix_sum,
      PrefixSum(xla::PadInDim(
          preds, xla::Zero(b, S32), /*dimno=*/0, /*pad_lo=*/0,
          /*pad_hi=*/NextPowerOfTwo(flattened_size) - flattened_size)));
  XlaOp prefix_sum = xla::SliceInDim(padded_prefix_sum, /*start_index=*/0,
                                     /*limit_index=*/flattened_size,
                                     /*stride=*/1, /*dimno=*/0);

  // We could compute out_idxs as
  //
  //   out_idxs[i] = preds[i] ? prefix_sum[i] - 1 : ⊥,
  //
  // but it's faster to compute it as
  //
  //   let ps = prefix_sum in
  //   out_idxs[i] =
  //     if i == 0: ps[i] != 0 ? ps[i] - 1 : ⊥
  //     else:      ps[i] != ps[i-1] ? ps[i] - 1 : ⊥
  //
  // because we read less memory.
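  // With the example above, prefix_sum = [1, 1, 1, 2, 2, 3] and its
  // right-shifted copy is [0, 1, 1, 1, 2, 2]; the two differ exactly at the
  // positions where preds is true, giving out_idxs = [0, ⊥, ⊥, 1, ⊥, 2].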
  XlaOp oob_idx = xla::ConstantR0WithType(b, S32, flattened_size);  // ⊥
  TF_ASSIGN_OR_RETURN(XlaOp prefix_sum_shifted, ShiftElemsRight(prefix_sum));
  XlaOp out_idxs = xla::Select(xla::Ne(prefix_sum, prefix_sum_shifted),
                               /*on_true=*/prefix_sum - xla::One(b, S32),
                               /*on_false=*/oob_idx);
  out_idxs = xla::Reshape(out_idxs, {flattened_size, 1});

  // tf.where returns an array of multidimensional indices where the condition
  // is true. For example:
  //
  //   input = [
  //     [F, T],
  //     [T, F],
  //     [F, F],
  //   ]
  //
  // results in
  //
  //   output = [
  //     [0,1], [1,0],
  //   ]
  //
  // Generate the list
  //
  //   iotas = [[0,...,0], [0,...,1], ..., [limit_0,...,limit_n]],
  //
  // and then scatter row i of iotas into row out_idxs[i] of the output.
  std::vector<XlaOp> iotas_to_concat;
  auto iota_shape = input_shape;
  iota_shape.set_element_type(S32);
  for (int64_t axis = 0; axis < iota_shape.rank(); ++axis) {
    iotas_to_concat.push_back(
        xla::Reshape(xla::Iota(b, iota_shape, axis), {flattened_size, 1}));
  }
  XlaOp iotas = xla::ConcatInDim(b, iotas_to_concat, /*dimension=*/1);

  // Scatter subcomputation. Instead of the usual `return p0 + p1`, simply
  // does `return p1`, because we just want to overwrite whatever was in the
  // scatter dest.
  xla::XlaComputation assn_computation = [&] {
    std::unique_ptr<xla::XlaBuilder> subb =
        b->CreateSubBuilder("where_op_scatter_assn");
    xla::Shape param_shape = xla::ShapeUtil::MakeShape(S32, {});
    xla::Parameter(subb.get(), 0, param_shape, "p0");
    xla::Parameter(subb.get(), 1, param_shape, "p1");
    // Simply return p1, the last op we created.
    return subb->BuildAndNoteError();
  }();

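  // Scatter dimension numbers: each entry of out_idxs (the index vector lies
  // along dim 1) selects a row (operand dim 0) of the output, and the
  // corresponding row of `iotas` (the update window along dim 1) is written
  // into that row.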
  xla::ScatterDimensionNumbers scatter_dnums;
  scatter_dnums.set_index_vector_dim(1);
  scatter_dnums.add_inserted_window_dims(0);
  scatter_dnums.add_scatter_dims_to_operand_dims(0);
  scatter_dnums.add_update_window_dims(1);
  XlaOp scattered = xla::Scatter(
      /*input=*/xla::Zeros(b, /*shape=*/xla::ShapeUtil::MakeShape(
                    S32, {flattened_size, iota_shape.rank()})),
      /*scatter_indices=*/out_idxs, /*updates=*/iotas,
      /*update_computation=*/assn_computation, scatter_dnums,
      /*indices_are_sorted=*/true, /*unique_indices=*/true);
  scattered = xla::ConvertElementType(scattered, ctx->output_xla_type(0));

  // Now count how many valid elements there are and slice off the tail of
  // `scattered`.
  XlaOp num_valid =
      xla::ReduceAll(xla::ConvertElementType(preds, S32), xla::Zero(b, S32),
                     xla::CreateScalarAddComputation(S32, b));
  StatusOr<XlaOp> rebounded_result = xla::SetDimensionSizeWithRebound(
      &ctx->value_inference(), scattered, num_valid, 0);
  if (rebounded_result.ok()) {
    return *rebounded_result;
  }
  // TODO(b/207187072): Remove special handling once dynamic reshape can also
  // be handled.
  return xla::SetDimensionSize(scattered, num_valid, 0);
}

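// XLA implementation of tf.where. Uses the prefix-sum approach on CPU/GPU and
// the sort-based approach otherwise (e.g. on TPU).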
class WhereOp : public XlaOpKernel {
 public:
  explicit WhereOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        use_prefix_sum_(ShouldUsePrefixSumImpl(ctx->device_type())) {}

  void Compile(XlaOpKernelContext* ctx) override {
    StatusOr<XlaOp> ret;
    if (use_prefix_sum_) {
      ret = CompileWhereWithPrefixSum(ctx);
    } else {
      ret = CompileWhereWithSort(ctx);
    }
    OP_REQUIRES_OK(ctx, ret.status());
    ctx->SetOutput(0, *ret);
  }

 private:
  bool use_prefix_sum_;
};

REGISTER_XLA_OP(Name("Where"), WhereOp);

}  // namespace
}  // namespace tensorflow