/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/dynamic_shaped_ops.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace {

using xla::S32;
using xla::XlaOp;

// "Shifts" a rank-1 array one element to the right, inserting a 0 at the
// beginning and cutting off the last element of the array.
//
// That is, transforms [x0, x1, ..., xn] into [0, x0, ..., xn-1].
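// For example (illustrative), [3, 1, 4, 1] becomes [0, 3, 1, 4].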
StatusOr<XlaOp> ShiftElemsRight(XlaOp x) {
  xla::XlaBuilder* b = x.builder();
  StatusOr<xla::Shape> shape = b->GetShape(x);
  TF_RETURN_IF_ERROR(shape.status());
  TF_RET_CHECK(shape->dimensions_size() == 1);
  int64_t n = shape->dimensions(0);

  XlaOp padded = xla::PadInDim(x, xla::Zero(b, shape->element_type()),
                               /*dimno=*/0, /*pad_lo=*/1, /*pad_hi=*/0);
  return xla::SliceInDim(padded, /*start_index=*/0, /*limit_index=*/n,
                         /*stride=*/1, /*dimno=*/0);
}

// Recursive prefix-sum algorithm.
//
// - Let the input be an array x.
// - Let evens be [x0, x2, ...].
// - Let odds be  [x1, x3, ...].
// - Let combined be evens + odds.
// - Let psum = prefix-sum(combined), recursively.
//
// Then the prefix-sum of x is the interleaving of psum - odds and psum.
// Written out, this is:
//
//   [psum[0] - odds[0], psum[0], psum[1] - odds[1], psum[1], ...].
//
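// For example (illustrative, using just two slices): if x = [1, 2, 3, 4],
// then evens = [1, 3], odds = [2, 4], combined = [3, 7], psum = [3, 10], and
// the interleaving [3 - 2, 3, 10 - 4, 10] = [1, 3, 6, 10] is the inclusive
// prefix-sum of x.
//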
// Requires: `arr` is a 1D S32 array whose length is padded to a power of 2.
//
// Optimization: Rather than split the input into two slices (evens/odds), we
// split it into kNumSlices pieces.  The basic algorithm is the same, but this
// reduces the number of GPU kernels we have to launch.
//
// There are much more efficient algorithms to be had!  In particular, on GPU
// this launches O(log4 n) kernels, but there are efficient algorithms that use
// just one kernel, see
// https://research.nvidia.com/publication/single-pass-parallel-prefix-scan-decoupled-look-back
//
// Nonetheless, this is much simpler than the algorithm in the paper above,
// while still being much faster than implementing tf.where by sorting the
// input.
StatusOr<XlaOp> PrefixSum(XlaOp arr) {
  xla::XlaBuilder* b = arr.builder();
  StatusOr<xla::Shape> input_shape = b->GetShape(arr);
  TF_RETURN_IF_ERROR(input_shape.status());

  TF_RET_CHECK(input_shape->dimensions_size() == 1);
  int64_t n = input_shape->dimensions(0);

  // The original input length must be a power of 2, but we recursively divide
  // it into kNumSlices chunks.  Assuming kNumSlices == 4, this means our
  // base-case needs to handle n == 1 (original length was a power of 4) or
  // n == 2 (original size was a power of 2).
  constexpr int kNumSlices = 4;
  if (n <= 1) {
    return arr;
  }
  if (n == 2) {
    TF_ASSIGN_OR_RETURN(XlaOp shifted, ShiftElemsRight(arr));
    return arr + shifted;
  }
  TF_RET_CHECK(n % kNumSlices == 0);

  std::array<XlaOp, kNumSlices> slices;
  for (int i = 0; i < slices.size(); i++) {
    slices[i] = xla::Slice(arr, /*start_indices=*/{i}, /*limit_indices=*/{n},
                           /*strides=*/{kNumSlices});
  }

  XlaOp combined = slices[0];
  for (int i = 1; i < kNumSlices; ++i) {
    combined = combined + slices[i];
  }

  TF_ASSIGN_OR_RETURN(XlaOp psum, PrefixSum(combined));

  std::array<XlaOp, kNumSlices> slices_psummed;
  slices_psummed[kNumSlices - 1] = psum;
  for (int i = kNumSlices - 2; i >= 0; --i) {
    slices_psummed[i] = slices_psummed[i + 1] - slices[i + 1];
  }

  // Interleave the slices.
  std::array<XlaOp, kNumSlices> slices_padded;
  for (int i = 0; i < kNumSlices; ++i) {
    xla::PaddingConfig padding_config;
    auto* dim = padding_config.add_dimensions();
    dim->set_edge_padding_low(i);
    dim->set_edge_padding_high(kNumSlices - i - 1);
    dim->set_interior_padding(kNumSlices - 1);
    slices_padded[i] =
        xla::Pad(slices_psummed[i], xla::Zero(b, S32), padding_config);
  }

  XlaOp ret = slices_padded[0];
  for (int i = 1; i < kNumSlices; ++i) {
    ret = ret + slices_padded[i];
  }

  return ret;
}

// prefix-sum works better on CPU/GPU, whereas sort works better on TPU.
bool ShouldUsePrefixSumImpl(const DeviceType& dt) {
  absl::string_view t = dt.type_string();
  return t == DEVICE_CPU_XLA_JIT || t == DEVICE_GPU_XLA_JIT ||
         t == DEVICE_XLA_CPU || t == DEVICE_XLA_GPU;
}

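// Implements tf.where by stably sorting the flattened predicate in descending
// order together with per-dimension iota indices: the coordinates of true
// elements end up at the front in row-major order, and the number of true
// elements becomes the dynamic size of the output's first dimension.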
StatusOr<XlaOp> CompileWhereWithSort(XlaOpKernelContext* ctx) {
  XlaOp condition = ctx->Input(0);
  TF_ASSIGN_OR_RETURN(xla::Shape input_shape,
                      ctx->builder()->GetShape(condition));
  auto iota_shape = input_shape;
  iota_shape.set_element_type(xla::S32);

  int64_t flattened_size = xla::Product(iota_shape.dimensions());
  XlaOp reshaped_condition = xla::Reshape(condition, {flattened_size});
  XlaOp zeros = xla::ZerosLike(reshaped_condition);
  XlaOp compared = xla::Ne(reshaped_condition, zeros);

  std::vector<XlaOp> to_sort = {compared};
  std::vector<xla::PrimitiveType> types_to_sort = {xla::PRED};
  // Generate iota for each dimension, which after combining becomes
  // indices of each element.
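  // For instance (illustrative), for a [2, 3] input the flattened axis-0 iota
  // is [0, 0, 0, 1, 1, 1] and the flattened axis-1 iota is [0, 1, 2, 0, 1, 2];
  // zipped together they give the row-major coordinates (0,0), (0,1), ...,
  // (1,2) of every element.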
  for (int64_t axis = 0; axis < iota_shape.rank(); ++axis) {
    XlaOp iota = xla::Iota(ctx->builder(), iota_shape, axis);
    XlaOp reshaped = xla::Reshape(iota, {flattened_size});
    to_sort.push_back(reshaped);
    types_to_sort.push_back(xla::S32);
  }

  XlaOp sorted = xla::Sort(
      to_sort, xla::CreateScalarGtComputation(types_to_sort, ctx->builder()),
      /*dimension=*/0, /*is_stable=*/true);
  std::vector<XlaOp> to_concat;
  for (int64_t i = 0; i < iota_shape.rank(); ++i) {
    XlaOp index_single_dim = xla::GetTupleElement(sorted, i + 1);
    to_concat.push_back(xla::Reshape(index_single_dim, {flattened_size, 1}));
  }

  XlaOp result = xla::ConcatInDim(ctx->builder(), to_concat, 1);
  result = xla::ConvertElementType(result, ctx->output_xla_type(0));

  // Dynamic padder will handle the dynamic dimension.
  XlaOp compared_int = xla::ConvertElementType(compared, xla::S32);
  XlaOp length =
      xla::ReduceAll(compared_int, xla::Zero(ctx->builder(), xla::S32),
                     xla::CreateScalarAddComputation(xla::S32, ctx->builder()));
  StatusOr<XlaOp> rebounded_result = xla::SetDimensionSizeWithRebound(
      &ctx->value_inference(), result, length, 0);
  if (rebounded_result.ok()) {
    return rebounded_result;
  }
  // TODO(b/207187072): Remove special handling once dynamic reshape can also
  // be handled.
  return xla::SetDimensionSize(result, length, 0);
}

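// Implements tf.where by computing a prefix sum over the flattened predicate
// to assign each true element a row in the output, then scattering each such
// element's coordinates into that row.  The number of true elements becomes
// the dynamic size of the output's first dimension.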
StatusOr<XlaOp> CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) {
  xla::XlaBuilder* b = ctx->builder();
  XlaOp condition = ctx->Input(0);

  TF_ASSIGN_OR_RETURN(xla::Shape input_shape, b->GetShape(condition));

  int64_t flattened_size = xla::Product(input_shape.dimensions());
  XlaOp reshaped_condition = xla::Reshape(condition, {flattened_size});
  XlaOp zeros = xla::ZerosLike(reshaped_condition);
  XlaOp preds =
      xla::ConvertElementType(xla::Ne(reshaped_condition, zeros), S32);

  // Given preds, we compute prefix_sum and out_idxs as in the following
  // example.
  //
  //   preds =      [T, F, F, T, F, T], therefore
  //   prefix_sum = [1, 1, 1, 2, 2, 3], and
  //   out_idxs   = [0, ⊥, ⊥, 1, ⊥, 2], where ⊥ is an OOB index.
  //
  // We then scatter into the result at the positions given by out_idxs.
  TF_ASSIGN_OR_RETURN(
      XlaOp padded_prefix_sum,
      PrefixSum(xla::PadInDim(
          preds, xla::Zero(b, S32), /*dimno=*/0, /*pad_lo=*/0,
          /*pad_hi=*/NextPowerOfTwo(flattened_size) - flattened_size)));
  XlaOp prefix_sum = xla::SliceInDim(padded_prefix_sum, /*start_index=*/0,
                                     /*limit_index=*/flattened_size,
                                     /*stride=*/1, /*dimno=*/0);

  // We could compute out_idxs as
  //
  //   out_idxs[i] = preds[i] ? prefix_sum[i] - 1 : ⊥,
  //
  // but it's faster to compute it as
  //
  //   let ps = prefix_sum in
  //   out_idxs[i] =
  //     if i == 0: ps[i] != 0       ? ps[i] - 1 : ⊥
  //     else:      ps[i] != ps[i-1] ? ps[i] - 1 : ⊥
  //
  // because we read less memory.
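  // With the example above, prefix_sum = [1, 1, 1, 2, 2, 3] and its right
  // shift is [0, 1, 1, 1, 2, 2]; comparing them elementwise for inequality
  // recovers exactly preds = [T, F, F, T, F, T].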
  XlaOp oob_idx = xla::ConstantR0WithType(b, S32, flattened_size);  // ⊥
  TF_ASSIGN_OR_RETURN(XlaOp prefix_sum_shifted, ShiftElemsRight(prefix_sum));
  XlaOp out_idxs = xla::Select(xla::Ne(prefix_sum, prefix_sum_shifted),
                               /*on_true=*/prefix_sum - xla::One(b, S32),
                               /*on_false=*/oob_idx);
  out_idxs = xla::Reshape(out_idxs, {flattened_size, 1});

  // tf.where returns an array of multidimensional indices where the condition
  // is true.  For example:
  //
  //    input =  [
  //      [F, T],
  //      [T, F],
  //      [F, F],
  //    ]
  //
  //  results in
  //
  //    output = [
  //      [0,1], [1,0],
  //    ]
  //
  // Generate the list
  //
  //   iotas = [[0,...,0], [0,...,1], ..., [limit_0,...,limit_n]],
  //
  // and then scatter iotas into the output at the positions given by out_idxs.
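  // Continuing the example above (illustrative): the flattened preds are
  // [F, T, T, F, F, F], so out_idxs = [⊥, 0, 1, ⊥, ⊥, ⊥] and
  // iotas = [[0,0], [0,1], [1,0], [1,1], [2,0], [2,1]].  The scatter writes
  // iotas[1] = [0,1] to row 0 and iotas[2] = [1,0] to row 1; all other rows
  // keep their zero-initialized values and are masked off by the dynamic
  // dimension size (2 here).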
  std::vector<XlaOp> iotas_to_concat;
  auto iota_shape = input_shape;
  iota_shape.set_element_type(S32);
  for (int64_t axis = 0; axis < iota_shape.rank(); ++axis) {
    iotas_to_concat.push_back(
        xla::Reshape(xla::Iota(b, iota_shape, axis), {flattened_size, 1}));
  }
  XlaOp iotas = xla::ConcatInDim(b, iotas_to_concat, /*dimension=*/1);

  // Scatter subcomputation.  Instead of the usual `return p0 + p1`, simply
  // does `return p1`, because we just want to overwrite whatever was in the
  // scatter dest.
  xla::XlaComputation assn_computation = [&] {
    std::unique_ptr<xla::XlaBuilder> subb =
        b->CreateSubBuilder("where_op_scatter_assn");
    xla::Shape param_shape = xla::ShapeUtil::MakeShape(S32, {});
    xla::Parameter(subb.get(), 0, param_shape, "p0");
    xla::Parameter(subb.get(), 1, param_shape, "p1");
    // Simply return p1, the last op we created.
    return subb->BuildAndNoteError();
  }();

  xla::ScatterDimensionNumbers scatter_dnums;
  scatter_dnums.set_index_vector_dim(1);
  scatter_dnums.add_inserted_window_dims(0);
  scatter_dnums.add_scatter_dims_to_operand_dims(0);
  scatter_dnums.add_update_window_dims(1);
  XlaOp scattered = xla::Scatter(
      /*input=*/xla::Zeros(b, /*shape=*/xla::ShapeUtil::MakeShape(
                               S32, {flattened_size, iota_shape.rank()})),
      /*scatter_indices=*/out_idxs, /*updates=*/iotas,
      /*update_computation=*/assn_computation, scatter_dnums,
      /*indices_are_sorted=*/true, /*unique_indices=*/true);
  scattered = xla::ConvertElementType(scattered, ctx->output_xla_type(0));

  // Now count how many valid elements there are and slice off the tail of
  // `scattered`.
  XlaOp num_valid =
      xla::ReduceAll(xla::ConvertElementType(preds, S32), xla::Zero(b, S32),
                     xla::CreateScalarAddComputation(S32, b));
  StatusOr<XlaOp> rebounded_result = xla::SetDimensionSizeWithRebound(
      &ctx->value_inference(), scattered, num_valid, 0);
  if (rebounded_result.ok()) {
    return *rebounded_result;
  }
  // TODO(b/207187072): Remove special handling once dynamic reshape can also
  // be handled.
  return xla::SetDimensionSize(scattered, num_valid, 0);
}

class WhereOp : public XlaOpKernel {
 public:
  explicit WhereOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx),
        use_prefix_sum_(ShouldUsePrefixSumImpl(ctx->device_type())) {}

  void Compile(XlaOpKernelContext* ctx) override {
    StatusOr<XlaOp> ret;
    if (use_prefix_sum_) {
      ret = CompileWhereWithPrefixSum(ctx);
    } else {
      ret = CompileWhereWithSort(ctx);
    }
    OP_REQUIRES_OK(ctx, ret.status());
    ctx->SetOutput(0, *ret);
  }

 private:
  bool use_prefix_sum_;
};

REGISTER_XLA_OP(Name("Where"), WhereOp);

}  // namespace
}  // namespace tensorflow