/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/slicing.h"

#include <algorithm>
#include <limits>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

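// Slices `window_sizes` elements of `input` starting at the run-time
// `base_indices`, then applies `strides` via a static Slice when any stride
// differs from 1.
//
// Example sketch (illustrative shapes only): for input f32[10],
// base_indices {i}, window_sizes {6}, strides {2}, the DynamicSlice produces
// f32[6] and the strided Slice keeps elements 0, 2, 4, giving f32[3].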
XlaOp DynamicStridedSlice(XlaOp input, absl::Span<const XlaOp> base_indices,
                          absl::Span<const int64> window_sizes,
                          absl::Span<const int64> strides) {
  XlaOp sliced_input = DynamicSlice(input, base_indices, window_sizes);
  if (std::any_of(strides.begin(), strides.end(),
                  [](int64_t stride) { return stride != 1; })) {
    sliced_input = Slice(sliced_input, std::vector<int64>(window_sizes.size()),
                         window_sizes, strides);
  }
  return sliced_input;
}

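// Slices the minor (trailing) dimensions of `x` between `start` and `end`;
// the major dimensions are taken in full.
//
// Example sketch (illustrative shapes only): for x f32[2, 3, 8], start {1},
// end {5}, the padded slice is [0:2, 0:3, 1:5], giving f32[2, 3, 4].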
XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
                       absl::Span<const int64> end) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RET_CHECK(start.size() == end.size());
    int64_t n_minor_dims = start.size();

    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));

    const int64_t n_dims = shape.rank();
    TF_RET_CHECK(n_minor_dims <= n_dims);
    auto major_dims = AsInt64Slice(shape.dimensions())
                          .subspan(
                              /*pos=*/0,
                              /*len=*/n_dims - n_minor_dims);

    // Prepends 0s in the major dims.
    std::vector<int64> padded_start(n_dims, 0);
    std::copy(start.begin(), start.end(),
              padded_start.begin() + major_dims.size());

    // Prepends the shape of the major dims.
    std::vector<int64> padded_end(n_dims);
    std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
    std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());

    std::vector<int64> strides(n_dims, 1);
    return Slice(x, padded_start, padded_end, strides);
  });
}

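// Writes `update` into `x` at the static offsets `start` (one entry per
// dimension of `x`) by materializing the offsets as int32 constants and
// emitting a DynamicUpdateSlice.
//
// Example sketch (illustrative shapes only): for x f32[4, 4],
// update f32[2, 2], start {1, 1}, the result equals x with x[1:3, 1:3]
// replaced by `update`.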
XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64_t n_dims = shape.rank();
    const int64_t start_size = start.size();
    TF_RET_CHECK(start_size == n_dims);

    // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
    std::vector<int32> start_as_int32(start.begin(), start.end());
    std::vector<XlaOp> start_ops(start.size());
    for (int i = 0, end = start.size(); i < end; ++i) {
      start_ops[i] = ConstantR0(builder, start_as_int32[i]);
    }
    return DynamicUpdateSlice(x, update, start_ops);
  });
}

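// Same as UpdateSlice, but `start` only addresses the minor dimensions; the
// major dimensions get a start index of 0.
//
// Example sketch (illustrative shapes only): for x f32[2, 4, 4] and
// start {1}, the padded start is {0, 0, 1}.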
XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
                             absl::Span<const int64> start) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64_t n_dims = shape.rank();
    const int64_t n_minor_dims = start.size();
    TF_RET_CHECK(n_minor_dims <= n_dims);
    std::vector<int64> padded_start(n_dims, 0);
    std::copy(start.begin(), start.end(),
              padded_start.begin() + (n_dims - n_minor_dims));
    return UpdateSlice(x, update, padded_start);
  });
}

namespace {

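// Returns the concatenation of `xs` followed by `ys`.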
std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
                                 absl::Span<const int64> ys) {
  std::vector<int64> output(xs.size() + ys.size());
  std::copy(xs.begin(), xs.end(), output.begin());
  std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
  return output;
}

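// Expands `starts` to one start index per dimension of `x` by prepending
// int32 zero scalars for the major dimensions.
//
// Example sketch (illustrative shapes only): for a rank-3 `x` and
// starts {i, j}, the result is {0, i, j}.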
StatusOr<std::vector<XlaOp>> PrependZerosInMajorDims(
    XlaOp x, absl::Span<const XlaOp> starts) {
  XlaBuilder* builder = x.builder();
  TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
  const int64_t n_dims = shape.rank();
  auto zero = ConstantR0<int32>(builder, 0);
  std::vector<XlaOp> padded_starts(n_dims, zero);
  for (int i = 0; i < starts.size(); ++i) {
    padded_starts[n_dims - starts.size() + i] = starts[i];
  }
  return padded_starts;
}

}  // namespace

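// Dynamic counterpart of SliceInMinorDims: slices `sizes` elements of the
// minor dimensions of `x` starting at the run-time `starts`; major
// dimensions are taken in full.
//
// Example sketch (illustrative shapes only): for x f32[2, 8, 8],
// starts {i, j}, sizes {4, 4}, the result is f32[2, 4, 4].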
XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
                              absl::Span<const int64> sizes) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64_t n_dims = shape.rank();
    int64_t n_minor_dims = starts.size();
    TF_RET_CHECK(n_minor_dims == sizes.size());
    TF_RET_CHECK(n_minor_dims <= n_dims);
    auto major_dims = AsInt64Slice(shape.dimensions())
                          .subspan(
                              /*pos=*/0,
                              /*len=*/n_dims - sizes.size());
    TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
    auto padded_sizes = ConcatVectors(major_dims, sizes);
    return DynamicSlice(x, padded_starts, padded_sizes);
  });
}

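// Dynamic counterpart of UpdateSliceInMinorDims: writes `update` into `x` at
// the run-time `starts` of the minor dimensions, with zero offsets in the
// major dimensions.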
XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
                                    absl::Span<const XlaOp> starts) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
    return DynamicUpdateSlice(x, update, padded_starts);
  });
}

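// Roughly mirrors torch.gather: out[i][j][k] = input[index[i][j][k]][j][k]
// for dim == 0, and analogously for other dims. When `sparse` is false the
// gather is lowered to a broadcast/one-hot mask plus a reduction over `dim`;
// otherwise full coordinate vectors are materialized and a single Gather is
// emitted. 64-bit indices are narrowed to U32 when the gathered dimension
// fits.
//
// Example sketch (illustrative shapes only): for input f32[3, 4],
// index s32[3, 2], dim = 1, the result has the index's shape, f32[3, 2].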
XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse) {
  XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    if (ShapeUtil::ElementHasBitWidth(index_shape, 64) &&
        input_shape.dimensions(dim) < std::numeric_limits<uint32>::max()) {
      index = ConvertElementType(index, U32);
      index_shape.set_element_type(U32);
    }
    if (index_shape.rank() == 1) {
      return TorchIndexSelect(input, index, 0);
    }
    if (!sparse) {
      std::vector<int64> index_broadcast_dims;
      std::vector<int64> input_broadcast_dims;
      std::vector<int64> sizes;
      for (int64_t i = 0; i < index_shape.rank(); ++i) {
        if (i < dim) {
          input_broadcast_dims.push_back(i);
          index_broadcast_dims.push_back(i);
        } else if (i == dim) {
          sizes.push_back(input_shape.dimensions(i));
          input_broadcast_dims.push_back(i);
          index_broadcast_dims.push_back(i + 1);
        } else {
          input_broadcast_dims.push_back(i + 1);
          index_broadcast_dims.push_back(i + 1);
        }
        sizes.push_back(index_shape.dimensions(i));
      }
      auto mask = Eq(
          BroadcastInDim(index, sizes, index_broadcast_dims),
          Iota(builder, ShapeUtil::MakeShape(index_shape.element_type(), sizes),
               dim));
      auto masked_input = Select(
          mask, BroadcastInDim(input, sizes, input_broadcast_dims),
          Zeros(builder,
                ShapeUtil::MakeShape(input_shape.element_type(), sizes)));
      return Reduce(masked_input, Zero(builder, input_shape.element_type()),
                    CreateScalarIdentityWithZeroComputation(
                        input_shape.element_type(), builder),
                    {dim});
    }

    ShapeUtil::AppendMajorDimension(1, &index_shape);
    std::vector<XlaOp> to_concat;

    to_concat.reserve(input_shape.rank());
    for (int64_t i = 0; i < input_shape.rank(); ++i) {
      if (i == dim) {
        to_concat.push_back(Reshape(index, index_shape.dimensions()));
      } else {
        to_concat.push_back(Iota(builder, index_shape, i));
      }
    }
    XlaOp gather_indices = ConcatInDim(builder, to_concat, input_shape.rank());
    std::vector<int64> slice_sizes(input_shape.rank(), 1);
    GatherDimensionNumbers gather_dnums;
    gather_dnums.set_index_vector_dim(input_shape.rank());
    for (int64_t i = 0; i < input_shape.rank(); ++i) {
      gather_dnums.add_collapsed_slice_dims(i);
      gather_dnums.add_start_index_map(i);
    }
    return Gather(input, gather_indices, gather_dnums, slice_sizes);
  });
}

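// Dense lowering of a torch.scatter-style update: a one-hot mask built from
// `index` selects the entries of `src` that target each position along
// `dim`; those entries are reduced with `combiner`, and the reduced tensor is
// then combined with `input`, again via `combiner`.
//
// Example sketch (illustrative shapes only, assuming `combiner` is Add): for
// input f32[3, 4], index s32[3, 2], src f32[3, 2], dim = 1, the result is
// `input` plus the scatter-add of `src` at the indexed columns.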
XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64_t dim,
                        const std::function<XlaOp(XlaOp, XlaOp)>& combiner) {
  XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    std::vector<int64> index_broadcast_dims;
    std::vector<int64> sizes;
    for (int64_t i = 0; i < index_shape.rank(); ++i) {
      if (i < dim) {
        index_broadcast_dims.push_back(i);
      } else {
        if (i == dim) {
          sizes.push_back(input_shape.dimensions(i));
        }
        index_broadcast_dims.push_back(i + 1);
      }
      sizes.push_back(index_shape.dimensions(i));
    }
    auto mask =
        Eq(BroadcastInDim(index, sizes, index_broadcast_dims),
           Iota(builder,
                ShapeUtil::MakeShape(index_shape.element_type(), sizes), dim));
    auto masked_src =
        Select(mask, BroadcastInDim(src, sizes, index_broadcast_dims),
               Zeros(builder,
                     ShapeUtil::MakeShape(input_shape.element_type(), sizes)));

    return combiner(
        input,
        Reduce(masked_src, Zero(builder, input_shape.element_type()),
               CreateScalarComputation("reducer", input_shape.element_type(),
                                       builder, combiner),
               {dim + 1}));
  });
}

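// Roughly mirrors torch.index_select, extended with `batch_dims` leading
// batch dimensions shared between `input` and `index`: selects the entries
// of `input` along `dim` at the positions given by `index`. 64-bit indices
// are narrowed to U32 when the selected dimension fits. Requires
// dim >= batch_dims.
//
// Example sketch (illustrative shapes only): for input f32[5, 8],
// index s32[3], dim = 0, batch_dims = 0, the result is f32[3, 8].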
XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64_t dim,
                       int64_t batch_dims) {
  XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
    if (dim < batch_dims) {
      return InvalidArgument(
          "Gather dim must be greater than or equal to the number of batch "
          "dims");
    }
    if (ShapeUtil::ElementHasBitWidth(index_shape, 64) &&
        input_shape.dimensions(dim) < std::numeric_limits<uint32>::max()) {
      index = ConvertElementType(index, U32);
      index_shape.set_element_type(U32);
    }
    std::vector<int64> slice_sizes = SpanToVector(input_shape.dimensions());
    GatherDimensionNumbers gather_dnums;
    gather_dnums.set_index_vector_dim(index_shape.rank());
    if (batch_dims > 0) {
      ShapeUtil::AppendMajorDimension(1, &index_shape);
      std::vector<XlaOp> to_concat;
      to_concat.reserve(batch_dims + 1);
      for (int64_t batch_dim = 0; batch_dim < batch_dims; ++batch_dim) {
        to_concat.push_back(Iota(builder, index_shape, batch_dim));
      }
      to_concat.push_back(Reshape(index, index_shape.dimensions()));
      index = ConcatInDim(builder, to_concat, gather_dnums.index_vector_dim());
    }
    for (int64_t i = 0; i < input_shape.rank(); ++i) {
      if (i < batch_dims || i == dim) {
        slice_sizes[i] = std::min<int64>(slice_sizes[i], 1);
        gather_dnums.add_collapsed_slice_dims(i);
        gather_dnums.add_start_index_map(i);
      } else {
        if (i < dim) {
          gather_dnums.add_offset_dims(i);
        } else {
          gather_dnums.add_offset_dims(i + gather_dnums.index_vector_dim() -
                                       (1 + batch_dims));
        }
      }
    }
    return Gather(input, index, gather_dnums, slice_sizes);
  });
}

}  // namespace xla