/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/slicing.h"

#include <algorithm>
#include <limits>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

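// Slices a window of `window_sizes` out of `input` at the runtime offsets
// `base_indices`, then applies `strides` within that window; the second,
// static Slice is skipped when every stride is 1.
//
// Illustrative sketch: for `x` of shape f32[10] and a runtime scalar index
// `i`,
//   DynamicStridedSlice(x, {i}, /*window_sizes=*/{6}, /*strides=*/{2})
// takes the six elements starting at `i` and keeps every second one,
// producing an f32[3] result.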
XlaOp DynamicStridedSlice(XlaOp input, absl::Span<const XlaOp> base_indices,
                          absl::Span<const int64> window_sizes,
                          absl::Span<const int64> strides) {
  XlaOp sliced_input = DynamicSlice(input, base_indices, window_sizes);
  if (std::any_of(strides.begin(), strides.end(),
                  [](int64 stride) { return stride != 1; })) {
    sliced_input = Slice(sliced_input, std::vector<int64>(window_sizes.size()),
                         window_sizes, strides);
  }
  return sliced_input;
}

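// Statically slices the minor (trailing) dimensions of `x`: `start` and `end`
// apply to the last start.size() dimensions, the major dimensions are kept
// whole, and all strides are 1.
//
// Illustrative sketch: for `x` of shape f32[4, 6],
//   SliceInMinorDims(x, /*start=*/{1}, /*end=*/{4})
// is equivalent to Slice(x, {0, 1}, {4, 4}, {1, 1}) and yields f32[4, 3].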
XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
                       absl::Span<const int64> end) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RET_CHECK(start.size() == end.size());
    int64 n_minor_dims = start.size();

    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));

    const int64 n_dims = shape.rank();
    TF_RET_CHECK(n_minor_dims <= n_dims);
    auto major_dims = AsInt64Slice(shape.dimensions())
                          .subspan(
                              /*pos=*/0,
                              /*len=*/n_dims - n_minor_dims);

    // Prepends 0s in the major dims.
    std::vector<int64> padded_start(n_dims, 0);
    std::copy(start.begin(), start.end(),
              padded_start.begin() + major_dims.size());

    // Prepends the shape of the major dims.
    std::vector<int64> padded_end(n_dims);
    std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
    std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());

    std::vector<int64> strides(n_dims, 1);
    return Slice(x, padded_start, padded_end, strides);
  });
}

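// Writes `update` into `x` at the compile-time offsets `start` (one offset
// per dimension of `x`), implemented as a DynamicUpdateSlice with constant
// start indices.
//
// Illustrative sketch: for `x` of shape f32[4, 6] and `update` of shape
// f32[2, 3], UpdateSlice(x, update, /*start=*/{1, 2}) overwrites the block
// x[1:3, 2:5].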
XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = shape.rank();
    const int64 start_size = start.size();
    TF_RET_CHECK(start_size == n_dims);

    // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
    std::vector<int32> start_as_int32(start.begin(), start.end());
    std::vector<XlaOp> start_ops(start.size());
    for (int i = 0, end = start.size(); i < end; ++i) {
      start_ops[i] = ConstantR0(builder, start_as_int32[i]);
    }
    return DynamicUpdateSlice(x, update, start_ops);
  });
}

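// Like UpdateSlice, but `start` only covers the minor (trailing) dimensions;
// the offsets in the major dimensions are implicitly 0.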
XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
                             absl::Span<const int64> start) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = shape.rank();
    const int64 n_minor_dims = start.size();
    TF_RET_CHECK(n_minor_dims <= n_dims);
    std::vector<int64> padded_start(n_dims, 0);
    std::copy(start.begin(), start.end(),
              padded_start.begin() + (n_dims - n_minor_dims));
    return UpdateSlice(x, update, padded_start);
  });
}

namespace {

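// Returns the concatenation of `xs` followed by `ys`.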
std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
                                 absl::Span<const int64> ys) {
  std::vector<int64> output(xs.size() + ys.size());
  std::copy(xs.begin(), xs.end(), output.begin());
  std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
  return output;
}

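// Expands `starts` (runtime start indices for the minor dimensions of `x`)
// into a full-rank index vector by prepending constant zeros for the major
// dimensions.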
StatusOr<std::vector<XlaOp>> PrependZerosInMajorDims(
    XlaOp x, absl::Span<const XlaOp> starts) {
  XlaBuilder* builder = x.builder();
  TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
  const int64 n_dims = shape.rank();
  auto zero = ConstantR0<int32>(builder, 0);
  std::vector<XlaOp> padded_starts(n_dims, zero);
  for (int i = 0; i < starts.size(); ++i) {
    padded_starts[n_dims - starts.size() + i] = starts[i];
  }
  return padded_starts;
}

}  // namespace

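// Dynamic counterpart of SliceInMinorDims: `starts` are runtime start indices
// and `sizes` are static slice sizes for the minor dimensions; the major
// dimensions are taken whole.
//
// Illustrative sketch: for `x` of shape f32[4, 6] and a runtime scalar index
// `j`, DynamicSliceInMinorDims(x, {j}, /*sizes=*/{3}) returns the f32[4, 3]
// block of columns [j, j + 3).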
XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
                              absl::Span<const int64> sizes) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = shape.rank();
    int64 n_minor_dims = starts.size();
    TF_RET_CHECK(n_minor_dims == sizes.size());
    TF_RET_CHECK(n_minor_dims <= n_dims);
    auto major_dims = AsInt64Slice(shape.dimensions())
                          .subspan(
                              /*pos=*/0,
                              /*len=*/n_dims - sizes.size());
    TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
    auto padded_sizes = ConcatVectors(major_dims, sizes);
    return DynamicSlice(x, padded_starts, padded_sizes);
  });
}

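// Dynamic counterpart of UpdateSliceInMinorDims: writes `update` into `x` at
// the runtime offsets `starts` in the minor dimensions, with offset 0 in the
// major dimensions.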
XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
                                    absl::Span<const XlaOp> starts) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
    return DynamicUpdateSlice(x, update, padded_starts);
  });
}

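// Mirrors torch.gather along dimension `dim`; for example, with dim == 0,
//   output[i][j] = input[index[i][j]][j].
// 64-bit indices are narrowed to U32 when the gathered dimension fits in
// uint32, and rank-1 indices fall back to TorchIndexSelect. With
// sparse == false the gather is expressed as a broadcast-compare-and-reduce
// over a mask instead of an explicit Gather op.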
XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse) {
  XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    if (ShapeUtil::ElementHasBitWidth(index_shape, 64) &&
        input_shape.dimensions(dim) < std::numeric_limits<uint32>::max()) {
      index = ConvertElementType(index, U32);
      index_shape.set_element_type(U32);
    }
    if (index_shape.rank() == 1) {
      return TorchIndexSelect(input, index, 0);
    }
    if (!sparse) {
      // Dense path: compare a broadcast of `index` against an iota along
      // `dim`, mask out the non-selected elements of `input`, and reduce the
      // input's `dim` away.
      std::vector<int64> index_broadcast_dims;
      std::vector<int64> input_broadcast_dims;
      std::vector<int64> sizes;
      for (int64 i = 0; i < index_shape.rank(); ++i) {
        if (i < dim) {
          input_broadcast_dims.push_back(i);
          index_broadcast_dims.push_back(i);
        } else if (i == dim) {
          sizes.push_back(input_shape.dimensions(i));
          input_broadcast_dims.push_back(i);
          index_broadcast_dims.push_back(i + 1);
        } else {
          input_broadcast_dims.push_back(i + 1);
          index_broadcast_dims.push_back(i + 1);
        }
        sizes.push_back(index_shape.dimensions(i));
      }
      auto mask = Eq(
          BroadcastInDim(index, sizes, index_broadcast_dims),
          Iota(builder, ShapeUtil::MakeShape(index_shape.element_type(), sizes),
               dim));
      auto masked_input = Select(
          mask, BroadcastInDim(input, sizes, input_broadcast_dims),
          Zeros(builder,
                ShapeUtil::MakeShape(input_shape.element_type(), sizes)));
      return Reduce(masked_input, Zero(builder, input_shape.element_type()),
                    CreateScalarIdentityWithZeroComputation(
                        input_shape.element_type(), builder),
                    {dim});
    }

    // Sparse path: materialize full coordinates for every output element; the
    // coordinate for `dim` comes from `index` and every other coordinate is
    // an iota, then gather one element per coordinate vector.
    ShapeUtil::AppendMajorDimension(1, &index_shape);
    std::vector<XlaOp> to_concat;

    to_concat.reserve(input_shape.rank());
    for (int64 i = 0; i < input_shape.rank(); ++i) {
      if (i == dim) {
        to_concat.push_back(Reshape(index, index_shape.dimensions()));
      } else {
        to_concat.push_back(Iota(builder, index_shape, i));
      }
    }
    XlaOp gather_indices = ConcatInDim(builder, to_concat, input_shape.rank());
    std::vector<int64> slice_sizes(input_shape.rank(), 1);
    GatherDimensionNumbers gather_dnums;
    gather_dnums.set_index_vector_dim(input_shape.rank());
    for (int64 i = 0; i < input_shape.rank(); ++i) {
      gather_dnums.add_collapsed_slice_dims(i);
      gather_dnums.add_start_index_map(i);
    }
    return Gather(input, gather_indices, gather_dnums, slice_sizes);
  });
}

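// Dense formulation of a torch.scatter-style update along dimension `dim`:
// `src` is masked by comparing a broadcast of `index` with an iota over the
// input's `dim`, the masked values are reduced with `combiner` (zeros stand
// in for unselected elements), and the reduction is combined with `input`
// using the same `combiner`.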
XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64 dim,
                        const std::function<XlaOp(XlaOp, XlaOp)>& combiner) {
  XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    std::vector<int64> index_broadcast_dims;
    std::vector<int64> sizes;
    for (int64 i = 0; i < index_shape.rank(); ++i) {
      if (i < dim) {
        index_broadcast_dims.push_back(i);
      } else {
        if (i == dim) {
          sizes.push_back(input_shape.dimensions(i));
        }
        index_broadcast_dims.push_back(i + 1);
      }
      sizes.push_back(index_shape.dimensions(i));
    }
    auto mask =
        Eq(BroadcastInDim(index, sizes, index_broadcast_dims),
           Iota(builder,
                ShapeUtil::MakeShape(index_shape.element_type(), sizes), dim));
    auto masked_src =
        Select(mask, BroadcastInDim(src, sizes, index_broadcast_dims),
               Zeros(builder,
                     ShapeUtil::MakeShape(input_shape.element_type(), sizes)));

    return combiner(
        input,
        Reduce(masked_src, Zero(builder, input_shape.element_type()),
               CreateScalarComputation("reducer", input_shape.element_type(),
                                       builder, combiner),
               {dim + 1}));
  });
}

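// Mirrors torch.index_select: picks the entries of `input` along dimension
// `dim` that are listed in `index`, treating the first `batch_dims`
// dimensions of `input` and `index` as shared batch dimensions. As in
// TorchGather, 64-bit indices are narrowed to U32 when the selected dimension
// fits in uint32. The op lowers to a single Gather whose dimension numbers
// collapse the batch dimensions and `dim`.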
XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim, int64 batch_dims) {
  XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
    if (dim < batch_dims) {
      return InvalidArgument(
          "Gather dim must be greater than or equal to the number of batch "
          "dims");
    }
    if (ShapeUtil::ElementHasBitWidth(index_shape, 64) &&
        input_shape.dimensions(dim) < std::numeric_limits<uint32>::max()) {
      index = ConvertElementType(index, U32);
      index_shape.set_element_type(U32);
    }
    std::vector<int64> slice_sizes = SpanToVector(input_shape.dimensions());
    GatherDimensionNumbers gather_dnums;
    gather_dnums.set_index_vector_dim(index_shape.rank());
    if (batch_dims > 0) {
      ShapeUtil::AppendMajorDimension(1, &index_shape);
      std::vector<XlaOp> to_concat;
      to_concat.reserve(batch_dims + 1);
      for (int64 batch_dim = 0; batch_dim < batch_dims; ++batch_dim) {
        to_concat.push_back(Iota(builder, index_shape, batch_dim));
      }
      to_concat.push_back(Reshape(index, index_shape.dimensions()));
      index = ConcatInDim(builder, to_concat, gather_dnums.index_vector_dim());
    }
    for (int64 i = 0; i < input_shape.rank(); ++i) {
      if (i < batch_dims || i == dim) {
        slice_sizes[i] = std::min<int64>(slice_sizes[i], 1);
        gather_dnums.add_collapsed_slice_dims(i);
        gather_dnums.add_start_index_map(i);
      } else {
        if (i < dim) {
          gather_dnums.add_offset_dims(i);
        } else {
          gather_dnums.add_offset_dims(i + gather_dnums.index_vector_dim() -
                                       (1 + batch_dims));
        }
      }
    }
    return Gather(input, index, gather_dnums, slice_sizes);
  });
}

}  // namespace xla