• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/slicing.h"
17 
18 #include <algorithm>
19 #include <limits>
20 #include <vector>
21 
22 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
23 #include "tensorflow/compiler/xla/client/lib/constants.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/compiler/xla/util.h"
26 
27 namespace xla {
28 
DynamicStridedSlice(XlaOp input,absl::Span<const XlaOp> base_indices,absl::Span<const int64> window_sizes,absl::Span<const int64> strides)29 XlaOp DynamicStridedSlice(XlaOp input, absl::Span<const XlaOp> base_indices,
30                           absl::Span<const int64> window_sizes,
31                           absl::Span<const int64> strides) {
32   XlaOp sliced_input = DynamicSlice(input, base_indices, window_sizes);
33   if (std::any_of(strides.begin(), strides.end(),
34                   [](int64 stride) { return stride != 1; })) {
35     sliced_input = Slice(sliced_input, std::vector<int64>(window_sizes.size()),
36                          window_sizes, strides);
37   }
38   return sliced_input;
39 }
40 
SliceInMinorDims(XlaOp x,absl::Span<const int64> start,absl::Span<const int64> end)41 XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
42                        absl::Span<const int64> end) {
43   XlaBuilder* builder = x.builder();
44   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
45     TF_RET_CHECK(start.size() == end.size());
46     int64 n_minor_dims = start.size();
47 
48     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
49 
50     const int64 n_dims = shape.rank();
51     TF_RET_CHECK(n_minor_dims <= n_dims);
52     auto major_dims = AsInt64Slice(shape.dimensions())
53                           .subspan(
54                               /*pos=*/0,
55                               /*len=*/n_dims - n_minor_dims);
56 
57     // Prepends 0s in the major dim
58     std::vector<int64> padded_start(n_dims, 0);
59     std::copy(start.begin(), start.end(),
60               padded_start.begin() + major_dims.size());
61 
62     // Prepends the shape of the major dims.
63     std::vector<int64> padded_end(n_dims);
64     std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
65     std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());
66 
67     std::vector<int64> strides(n_dims, 1);
68     return Slice(x, padded_start, padded_end, strides);
69   });
70 }
71 
UpdateSlice(XlaOp x,XlaOp update,absl::Span<const int64> start)72 XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start) {
73   XlaBuilder* builder = x.builder();
74   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
75     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
76     const int64 n_dims = shape.rank();
77     const int64 start_size = start.size();
78     TF_RET_CHECK(start_size == n_dims);
79 
80     // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
81     std::vector<int32> start_as_int32(start.begin(), start.end());
82     std::vector<XlaOp> start_ops(start.size());
83     for (int i = 0, end = start.size(); i < end; ++i) {
84       start_ops[i] = ConstantR0(builder, start_as_int32[i]);
85     }
86     return DynamicUpdateSlice(x, update, start_ops);
87   });
88 }
89 
UpdateSliceInMinorDims(XlaOp x,XlaOp update,absl::Span<const int64> start)90 XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
91                              absl::Span<const int64> start) {
92   XlaBuilder* builder = x.builder();
93   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
94     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
95     const int64 n_dims = shape.rank();
96     const int64 n_minor_dims = start.size();
97     TF_RET_CHECK(n_minor_dims <= n_dims);
98     std::vector<int64> padded_start(n_dims, 0);
99     std::copy(start.begin(), start.end(),
100               padded_start.begin() + (n_dims - n_minor_dims));
101     return UpdateSlice(x, update, padded_start);
102   });
103 }
104 
105 namespace {
106 
ConcatVectors(absl::Span<const int64> xs,absl::Span<const int64> ys)107 std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
108                                  absl::Span<const int64> ys) {
109   std::vector<int64> output(xs.size() + ys.size());
110   std::copy(xs.begin(), xs.end(), output.begin());
111   std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
112   return output;
113 }
114 
PrependZerosInMajorDims(XlaOp x,absl::Span<const XlaOp> starts)115 StatusOr<std::vector<XlaOp>> PrependZerosInMajorDims(
116     XlaOp x, absl::Span<const XlaOp> starts) {
117   XlaBuilder* builder = x.builder();
118   TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
119   const int64 n_dims = shape.rank();
120   auto zero = ConstantR0<int32>(builder, 0);
121   std::vector<XlaOp> padded_starts(n_dims, zero);
122   for (int i = 0; i < starts.size(); ++i) {
123     padded_starts[n_dims - starts.size() + i] = starts[i];
124   }
125   return padded_starts;
126 }
127 
128 }  // namespace
129 
DynamicSliceInMinorDims(XlaOp x,absl::Span<const XlaOp> starts,absl::Span<const int64> sizes)130 XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
131                               absl::Span<const int64> sizes) {
132   XlaBuilder* builder = x.builder();
133   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
134     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
135     const int64 n_dims = shape.rank();
136     int64 n_minor_dims = starts.size();
137     TF_RET_CHECK(n_minor_dims == sizes.size());
138     TF_RET_CHECK(n_minor_dims <= n_dims);
139     auto major_dims = AsInt64Slice(shape.dimensions())
140                           .subspan(
141                               /*pos=*/0,
142                               /*len=*/n_dims - sizes.size());
143     TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
144     auto padded_sizes = ConcatVectors(major_dims, sizes);
145     return DynamicSlice(x, padded_starts, padded_sizes);
146   });
147 }
148 
DynamicUpdateSliceInMinorDims(XlaOp x,XlaOp update,absl::Span<const XlaOp> starts)149 XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
150                                     absl::Span<const XlaOp> starts) {
151   XlaBuilder* builder = x.builder();
152   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
153     TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
154     return DynamicUpdateSlice(x, update, padded_starts);
155   });
156 }
157 
TorchGather(XlaOp input,XlaOp index,int64 dim,bool sparse)158 XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse) {
159   XlaBuilder* builder = input.builder();
160   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
161     TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
162     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
163     if (ShapeUtil::ElementHasBitWidth(index_shape, 64) &&
164         input_shape.dimensions(dim) < std::numeric_limits<uint32>::max()) {
165       index = ConvertElementType(index, U32);
166       index_shape.set_element_type(U32);
167     }
168     if (index_shape.rank() == 1) {
169       return TorchIndexSelect(input, index, 0);
170     }
171     if (!sparse) {
172       std::vector<int64> index_broadcast_dims;
173       std::vector<int64> input_broadcast_dims;
174       std::vector<int64> sizes;
175       for (int64 i = 0; i < index_shape.rank(); ++i) {
176         if (i < dim) {
177           input_broadcast_dims.push_back(i);
178           index_broadcast_dims.push_back(i);
179         } else if (i == dim) {
180           sizes.push_back(input_shape.dimensions(i));
181           input_broadcast_dims.push_back(i);
182           index_broadcast_dims.push_back(i + 1);
183         } else {
184           input_broadcast_dims.push_back(i + 1);
185           index_broadcast_dims.push_back(i + 1);
186         }
187         sizes.push_back(index_shape.dimensions(i));
188       }
189       auto mask = Eq(
190           BroadcastInDim(index, sizes, index_broadcast_dims),
191           Iota(builder, ShapeUtil::MakeShape(index_shape.element_type(), sizes),
192                dim));
193       auto masked_input = Select(
194           mask, BroadcastInDim(input, sizes, input_broadcast_dims),
195           Zeros(builder,
196                 ShapeUtil::MakeShape(input_shape.element_type(), sizes)));
197       return Reduce(masked_input, Zero(builder, input_shape.element_type()),
198                     CreateScalarIdentityWithZeroComputation(
199                         input_shape.element_type(), builder),
200                     {dim});
201     }
202 
203     ShapeUtil::AppendMajorDimension(1, &index_shape);
204     std::vector<XlaOp> to_concat;
205 
206     to_concat.reserve(input_shape.rank());
207     for (int64 i = 0; i < input_shape.rank(); ++i) {
208       if (i == dim) {
209         to_concat.push_back(Reshape(index, index_shape.dimensions()));
210       } else {
211         to_concat.push_back(Iota(builder, index_shape, i));
212       }
213     }
214     XlaOp gather_indices = ConcatInDim(builder, to_concat, input_shape.rank());
215     std::vector<int64> slice_sizes(input_shape.rank(), 1);
216     GatherDimensionNumbers gather_dnums;
217     gather_dnums.set_index_vector_dim(input_shape.rank());
218     for (int64 i = 0; i < input_shape.rank(); ++i) {
219       gather_dnums.add_collapsed_slice_dims(i);
220       gather_dnums.add_start_index_map(i);
221     }
222     return Gather(input, gather_indices, gather_dnums, slice_sizes);
223   });
224 }
225 
TorchScatterDense(XlaOp input,XlaOp index,XlaOp src,int64 dim,const std::function<XlaOp (XlaOp,XlaOp)> & combiner)226 XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64 dim,
227                         const std::function<XlaOp(XlaOp, XlaOp)>& combiner) {
228   XlaBuilder* builder = input.builder();
229   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
230     TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
231     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
232     std::vector<int64> index_broadcast_dims;
233     std::vector<int64> sizes;
234     for (int64 i = 0; i < index_shape.rank(); ++i) {
235       if (i < dim) {
236         index_broadcast_dims.push_back(i);
237       } else {
238         if (i == dim) {
239           sizes.push_back(input_shape.dimensions(i));
240         }
241         index_broadcast_dims.push_back(i + 1);
242       }
243       sizes.push_back(index_shape.dimensions(i));
244     }
245     auto mask =
246         Eq(BroadcastInDim(index, sizes, index_broadcast_dims),
247            Iota(builder,
248                 ShapeUtil::MakeShape(index_shape.element_type(), sizes), dim));
249     auto masked_src =
250         Select(mask, BroadcastInDim(src, sizes, index_broadcast_dims),
251                Zeros(builder,
252                      ShapeUtil::MakeShape(input_shape.element_type(), sizes)));
253 
254     return combiner(
255         input,
256         Reduce(masked_src, Zero(builder, input_shape.element_type()),
257                CreateScalarComputation("reducer", input_shape.element_type(),
258                                        builder, combiner),
259                {dim + 1}));
260   });
261 }
262 
TorchIndexSelect(XlaOp input,XlaOp index,int64 dim,int64 batch_dims)263 XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim, int64 batch_dims) {
264   XlaBuilder* builder = input.builder();
265   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
266     TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
267     TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
268     if (dim < batch_dims) {
269       return InvalidArgument(
270           "Gather dim must be greater than or equal to the number of batch "
271           "dims");
272     }
273     if (ShapeUtil::ElementHasBitWidth(index_shape, 64) &&
274         input_shape.dimensions(dim) < std::numeric_limits<uint32>::max()) {
275       index = ConvertElementType(index, U32);
276       index_shape.set_element_type(U32);
277     }
278     std::vector<int64> slice_sizes = SpanToVector(input_shape.dimensions());
279     GatherDimensionNumbers gather_dnums;
280     gather_dnums.set_index_vector_dim(index_shape.rank());
281     if (batch_dims > 0) {
282       ShapeUtil::AppendMajorDimension(1, &index_shape);
283       std::vector<XlaOp> to_concat;
284       to_concat.reserve(batch_dims + 1);
285       for (int64 batch_dim = 0; batch_dim < batch_dims; ++batch_dim) {
286         to_concat.push_back(Iota(builder, index_shape, batch_dim));
287       }
288       to_concat.push_back(Reshape(index, index_shape.dimensions()));
289       index = ConcatInDim(builder, to_concat, gather_dnums.index_vector_dim());
290     }
291     for (int64 i = 0; i < input_shape.rank(); ++i) {
292       if (i < batch_dims || i == dim) {
293         slice_sizes[i] = std::min<int64>(slice_sizes[i], 1);
294         gather_dnums.add_collapsed_slice_dims(i);
295         gather_dnums.add_start_index_map(i);
296       } else {
297         if (i < dim) {
298           gather_dnums.add_offset_dims(i);
299         } else {
300           gather_dnums.add_offset_dims(i + gather_dnums.index_vector_dim() -
301                                        (1 + batch_dims));
302         }
303       }
304     }
305     return Gather(input, index, gather_dnums, slice_sizes);
306   });
307 }
308 
309 }  // namespace xla
310