• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/slicing.h"
17 
18 #include <algorithm>
19 #include <limits>
20 #include <vector>
21 
22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
23 #include "tensorflow/compiler/xla/client/lib/constants.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/util.h"
26 
27 namespace xla {
28 
DynamicStridedSlice(XlaOp input,absl::Span<const XlaOp> base_indices,absl::Span<const int64> window_sizes,absl::Span<const int64> strides)29 XlaOp DynamicStridedSlice(XlaOp input, absl::Span<const XlaOp> base_indices,
30                           absl::Span<const int64> window_sizes,
31                           absl::Span<const int64> strides) {
32   XlaOp sliced_input = DynamicSlice(input, base_indices, window_sizes);
33   if (std::any_of(strides.begin(), strides.end(),
34                   [](int64_t stride) { return stride != 1; })) {
35     sliced_input = Slice(sliced_input, std::vector<int64>(window_sizes.size()),
36                          window_sizes, strides);
37   }
38   return sliced_input;
39 }
40 
SliceInMinorDims(XlaOp x,absl::Span<const int64> start,absl::Span<const int64> end)41 XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
42                        absl::Span<const int64> end) {
43   XlaBuilder* builder = x.builder();
44   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
45     TF_RET_CHECK(start.size() == end.size());
46     int64_t n_minor_dims = start.size();
47 
48     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
49 
50     const int64_t n_dims = shape.rank();
51     TF_RET_CHECK(n_minor_dims <= n_dims);
52     auto major_dims = AsInt64Slice(shape.dimensions())
53                           .subspan(
54                               /*pos=*/0,
55                               /*len=*/n_dims - n_minor_dims);
56 
57     // Prepends 0s in the major dim
58     std::vector<int64> padded_start(n_dims, 0);
59     std::copy(start.begin(), start.end(),
60               padded_start.begin() + major_dims.size());
61 
62     // Prepends the shape of the major dims.
63     std::vector<int64> padded_end(n_dims);
64     std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
65     std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());
66 
67     std::vector<int64> strides(n_dims, 1);
68     return Slice(x, padded_start, padded_end, strides);
69   });
70 }
71 
UpdateSlice(XlaOp x,XlaOp update,absl::Span<const int64> start)72 XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start) {
73   XlaBuilder* builder = x.builder();
74   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
75     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
76     const int64_t n_dims = shape.rank();
77     const int64_t start_size = start.size();
78     TF_RET_CHECK(start_size == n_dims);
79 
80     // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
81     std::vector<int32> start_as_int32(start.begin(), start.end());
82     std::vector<XlaOp> start_ops(start.size());
83     for (int i = 0, end = start.size(); i < end; ++i) {
84       start_ops[i] = ConstantR0(builder, start_as_int32[i]);
85     }
86     return DynamicUpdateSlice(x, update, start_ops);
87   });
88 }
89 
UpdateSliceInMinorDims(XlaOp x,XlaOp update,absl::Span<const int64> start)90 XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
91                              absl::Span<const int64> start) {
92   XlaBuilder* builder = x.builder();
93   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
94     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
95     const int64_t n_dims = shape.rank();
96     const int64_t n_minor_dims = start.size();
97     TF_RET_CHECK(n_minor_dims <= n_dims);
98     std::vector<int64> padded_start(n_dims, 0);
99     std::copy(start.begin(), start.end(),
100               padded_start.begin() + (n_dims - n_minor_dims));
101     return UpdateSlice(x, update, padded_start);
102   });
103 }
104 
105 namespace {
106 
ConcatVectors(absl::Span<const int64> xs,absl::Span<const int64> ys)107 std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
108                                  absl::Span<const int64> ys) {
109   std::vector<int64> output(xs.size() + ys.size());
110   std::copy(xs.begin(), xs.end(), output.begin());
111   std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
112   return output;
113 }
114 
PrependZerosInMajorDims(XlaOp x,absl::Span<const XlaOp> starts)115 StatusOr<std::vector<XlaOp>> PrependZerosInMajorDims(
116     XlaOp x, absl::Span<const XlaOp> starts) {
117   XlaBuilder* builder = x.builder();
118   TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
119   const int64_t n_dims = shape.rank();
120   auto zero = ConstantR0<int32>(builder, 0);
121   std::vector<XlaOp> padded_starts(n_dims, zero);
122   for (int i = 0; i < starts.size(); ++i) {
123     padded_starts[n_dims - starts.size() + i] = starts[i];
124   }
125   return padded_starts;
126 }
127 
128 }  // namespace
129 
DynamicSliceInMinorDims(XlaOp x,absl::Span<const XlaOp> starts,absl::Span<const int64> sizes)130 XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
131                               absl::Span<const int64> sizes) {
132   XlaBuilder* builder = x.builder();
133   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
134     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
135     const int64_t n_dims = shape.rank();
136     int64_t n_minor_dims = starts.size();
137     TF_RET_CHECK(n_minor_dims == sizes.size());
138     TF_RET_CHECK(n_minor_dims <= n_dims);
139     auto major_dims = AsInt64Slice(shape.dimensions())
140                           .subspan(
141                               /*pos=*/0,
142                               /*len=*/n_dims - sizes.size());
143     TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
144     auto padded_sizes = ConcatVectors(major_dims, sizes);
145     return DynamicSlice(x, padded_starts, padded_sizes);
146   });
147 }
148 
DynamicUpdateSliceInMinorDims(XlaOp x,XlaOp update,absl::Span<const XlaOp> starts)149 XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
150                                     absl::Span<const XlaOp> starts) {
151   XlaBuilder* builder = x.builder();
152   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
153     TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
154     return DynamicUpdateSlice(x, update, padded_starts);
155   });
156 }
157 
TorchGather(XlaOp input,XlaOp index,int64_t dim,bool sparse)158 XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) {
159   XlaBuilder* builder = input.builder();
160   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
161     TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
162     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
163     if (ShapeUtil::ElementHasBitWidth(index_shape, 64) &&
164         input_shape.dimensions(dim) < std::numeric_limits<uint32>::max()) {
165       index = ConvertElementType(index, U32);
166       index_shape.set_element_type(U32);
167     }
168     if (index_shape.rank() == 1) {
169       return TorchIndexSelect(input, index, 0);
170     }
171     if (!sparse) {
172       std::vector<int64> index_broadcast_dims;
173       std::vector<int64> input_broadcast_dims;
174       std::vector<int64> sizes;
175       for (int64_t i = 0; i < index_shape.rank(); ++i) {
176         if (i < dim) {
177           input_broadcast_dims.push_back(i);
178           index_broadcast_dims.push_back(i);
179         } else if (i == dim) {
180           sizes.push_back(input_shape.dimensions(i));
181           input_broadcast_dims.push_back(i);
182           index_broadcast_dims.push_back(i + 1);
183         } else {
184           input_broadcast_dims.push_back(i + 1);
185           index_broadcast_dims.push_back(i + 1);
186         }
187         sizes.push_back(index_shape.dimensions(i));
188       }
189       auto mask = Eq(
190           BroadcastInDim(index, sizes, index_broadcast_dims),
191           Iota(builder, ShapeUtil::MakeShape(index_shape.element_type(), sizes),
192                dim));
193       auto masked_input = Select(
194           mask, BroadcastInDim(input, sizes, input_broadcast_dims),
195           Zeros(builder,
196                 ShapeUtil::MakeShape(input_shape.element_type(), sizes)));
197       return Reduce(masked_input, Zero(builder, input_shape.element_type()),
198                     CreateScalarIdentityWithZeroComputation(
199                         input_shape.element_type(), builder),
200                     {dim});
201     }
202 
203     ShapeUtil::AppendMajorDimension(1, &index_shape);
204     std::vector<XlaOp> to_concat;
205 
206     to_concat.reserve(input_shape.rank());
207     for (int64_t i = 0; i < input_shape.rank(); ++i) {
208       if (i == dim) {
209         to_concat.push_back(Reshape(index, index_shape.dimensions()));
210       } else {
211         to_concat.push_back(Iota(builder, index_shape, i));
212       }
213     }
214     XlaOp gather_indices = ConcatInDim(builder, to_concat, input_shape.rank());
215     std::vector<int64> slice_sizes(input_shape.rank(), 1);
216     GatherDimensionNumbers gather_dnums;
217     gather_dnums.set_index_vector_dim(input_shape.rank());
218     for (int64_t i = 0; i < input_shape.rank(); ++i) {
219       gather_dnums.add_collapsed_slice_dims(i);
220       gather_dnums.add_start_index_map(i);
221     }
222     return Gather(input, gather_indices, gather_dnums, slice_sizes);
223   });
224 }
225 
TorchScatterDense(XlaOp input,XlaOp index,XlaOp src,int64_t dim,const std::function<XlaOp (XlaOp,XlaOp)> & combiner)226 XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64_t dim,
227                         const std::function<XlaOp(XlaOp, XlaOp)>& combiner) {
228   XlaBuilder* builder = input.builder();
229   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
230     TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
231     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
232     std::vector<int64> index_broadcast_dims;
233     std::vector<int64> sizes;
234     for (int64_t i = 0; i < index_shape.rank(); ++i) {
235       if (i < dim) {
236         index_broadcast_dims.push_back(i);
237       } else {
238         if (i == dim) {
239           sizes.push_back(input_shape.dimensions(i));
240         }
241         index_broadcast_dims.push_back(i + 1);
242       }
243       sizes.push_back(index_shape.dimensions(i));
244     }
245     auto mask =
246         Eq(BroadcastInDim(index, sizes, index_broadcast_dims),
247            Iota(builder,
248                 ShapeUtil::MakeShape(index_shape.element_type(), sizes), dim));
249     auto masked_src =
250         Select(mask, BroadcastInDim(src, sizes, index_broadcast_dims),
251                Zeros(builder,
252                      ShapeUtil::MakeShape(input_shape.element_type(), sizes)));
253 
254     return combiner(
255         input,
256         Reduce(masked_src, Zero(builder, input_shape.element_type()),
257                CreateScalarComputation("reducer", input_shape.element_type(),
258                                        builder, combiner),
259                {dim + 1}));
260   });
261 }
262 
TorchIndexSelect(XlaOp input,XlaOp index,int64_t dim,int64_t batch_dims)263 XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64_t dim,
264                        int64_t batch_dims) {
265   XlaBuilder* builder = input.builder();
266   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
267     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
268     TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
269     if (dim < batch_dims) {
270       return InvalidArgument(
271           "Gather dim must be greater than or equal to the number of batch "
272           "dims");
273     }
274     if (ShapeUtil::ElementHasBitWidth(index_shape, 64) &&
275         input_shape.dimensions(dim) < std::numeric_limits<uint32>::max()) {
276       index = ConvertElementType(index, U32);
277       index_shape.set_element_type(U32);
278     }
279     std::vector<int64> slice_sizes = SpanToVector(input_shape.dimensions());
280     GatherDimensionNumbers gather_dnums;
281     gather_dnums.set_index_vector_dim(index_shape.rank());
282     if (batch_dims > 0) {
283       ShapeUtil::AppendMajorDimension(1, &index_shape);
284       std::vector<XlaOp> to_concat;
285       to_concat.reserve(batch_dims + 1);
286       for (int64_t batch_dim = 0; batch_dim < batch_dims; ++batch_dim) {
287         to_concat.push_back(Iota(builder, index_shape, batch_dim));
288       }
289       to_concat.push_back(Reshape(index, index_shape.dimensions()));
290       index = ConcatInDim(builder, to_concat, gather_dnums.index_vector_dim());
291     }
292     for (int64_t i = 0; i < input_shape.rank(); ++i) {
293       if (i < batch_dims || i == dim) {
294         slice_sizes[i] = std::min<int64>(slice_sizes[i], 1);
295         gather_dnums.add_collapsed_slice_dims(i);
296         gather_dnums.add_start_index_map(i);
297       } else {
298         if (i < dim) {
299           gather_dnums.add_offset_dims(i);
300         } else {
301           gather_dnums.add_offset_dims(i + gather_dnums.index_vector_dim() -
302                                        (1 + batch_dims));
303         }
304       }
305     }
306     return Gather(input, index, gather_dnums, slice_sizes);
307   });
308 }
309 
310 }  // namespace xla
311