/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace xla {

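// Slices the minor (trailing) dimensions of `x` over [start, end) with unit
// strides, keeping the major dimensions whole; e.g. SliceInMinorDims(x, {1, 2},
// {3, 4}) roughly corresponds to x[..., 1:3, 2:4].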
XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
                       absl::Span<const int64> end) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RET_CHECK(start.size() == end.size());
    int64 n_minor_dims = start.size();

    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));

    const int64 n_dims = shape.rank();
    TF_RET_CHECK(n_minor_dims <= n_dims);
    auto major_dims = AsInt64Slice(shape.dimensions())
                          .subspan(
                              /*pos=*/0,
                              /*len=*/n_dims - n_minor_dims);

    // Prepends 0s for the major dims so that they are sliced from the origin.
    std::vector<int64> padded_start(n_dims, 0);
    std::copy(start.begin(), start.end(),
              padded_start.begin() + major_dims.size());

    // Prepends the sizes of the major dims so that they are sliced in full.
    std::vector<int64> padded_end(n_dims);
    std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
    std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());

    std::vector<int64> strides(n_dims, 1);
    return Slice(x, padded_start, padded_end, strides);
  });
}

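// Returns a copy of `x` in which the region whose origin is `start` has been
// overwritten with `update` (a DynamicUpdateSlice with constant start
// indices).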
XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = shape.rank();
    TF_RET_CHECK(start.size() == n_dims);

    // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
    std::vector<int32> start_as_int32(start.begin(), start.end());
    std::vector<XlaOp> start_ops(start.size());
    for (int i = 0; i < start.size(); ++i) {
      start_ops[i] = ConstantR0(builder, start_as_int32[i]);
    }
    return DynamicUpdateSlice(x, update, start_ops);
  });
}

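// Like UpdateSlice, but `start` addresses only the minor (trailing) dimensions
// of `x`; the start indices for the major dimensions are implicitly zero.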
XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
                             absl::Span<const int64> start) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = shape.rank();
    const int64 n_minor_dims = start.size();
    TF_RET_CHECK(n_minor_dims <= n_dims);
    std::vector<int64> padded_start(n_dims, 0);
    std::copy(start.begin(), start.end(),
              padded_start.begin() + (n_dims - n_minor_dims));
    return UpdateSlice(x, update, padded_start);
  });
}

namespace {

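// Concatenates two index vectors: `xs` followed by `ys`.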
std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
                                 absl::Span<const int64> ys) {
  std::vector<int64> output(xs.size() + ys.size());
  std::copy(xs.begin(), xs.end(), output.begin());
  std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
  return output;
}

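// Expands `starts`, which addresses only the minor dimensions of `x`, into a
// full set of start indices by prepending scalar zeros for the major
// dimensions.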
StatusOr<std::vector<XlaOp>> PrependZerosInMajorDims(
    XlaOp x, absl::Span<const XlaOp> starts) {
  XlaBuilder* builder = x.builder();
  TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
  const int64 n_dims = shape.rank();
  auto zero = ConstantR0<int32>(builder, 0);
  std::vector<XlaOp> padded_starts(n_dims, zero);
  for (int i = 0; i < starts.size(); ++i) {
    padded_starts[n_dims - starts.size() + i] = starts[i];
  }
  return padded_starts;
}

}  // namespace

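// Dynamic version of SliceInMinorDims: the start indices of the minor
// dimensions are runtime values in `starts` and their extents are given by
// `sizes`; the major dimensions are sliced in full.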
XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
                              absl::Span<const int64> sizes) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = shape.rank();
    int64 n_minor_dims = starts.size();
    TF_RET_CHECK(n_minor_dims == sizes.size());
    TF_RET_CHECK(n_minor_dims <= n_dims);
    auto major_dims = AsInt64Slice(shape.dimensions())
                          .subspan(
                              /*pos=*/0,
                              /*len=*/n_dims - sizes.size());
    TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
    auto padded_sizes = ConcatVectors(major_dims, sizes);
    return DynamicSlice(x, padded_starts, padded_sizes);
  });
}

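// Dynamic version of UpdateSliceInMinorDims: `starts` holds runtime start
// indices for the minor dimensions; the major start indices are zero.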
XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
                                    absl::Span<const XlaOp> starts) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto padded_starts, PrependZerosInMajorDims(x, starts));
    return DynamicUpdateSlice(x, update, padded_starts);
  });
}

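// Mimics torch.gather: produces a result with the shape of `index`, where
// result[i][j][k] = input[index[i][j][k]][j][k] for dim == 0 (and analogously
// for the other choices of `dim`).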
XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim) {
  XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
    ShapeUtil::AppendMajorDimension(1, &index_shape);
    std::vector<XlaOp> to_concat;
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    to_concat.reserve(input_shape.rank());
    // Build a full index vector for every output element: the user-provided
    // `index` supplies the coordinate along `dim`, and iotas supply the
    // coordinates of the remaining dimensions.
    for (int64 i = 0; i < input_shape.rank(); ++i) {
      if (i == dim) {
        to_concat.push_back(Reshape(index, index_shape.dimensions()));
      } else {
        to_concat.push_back(Iota(builder, index_shape, i));
      }
    }
    XlaOp gather_indices = ConcatInDim(builder, to_concat, input_shape.rank());
    // Gather scalar elements: every input dimension is collapsed and mapped to
    // a component of the index vector.
    std::vector<int64> slice_sizes(input_shape.rank(), 1);
    GatherDimensionNumbers gather_dnums;
    gather_dnums.set_index_vector_dim(input_shape.rank());
    for (int64 i = 0; i < input_shape.rank(); ++i) {
      gather_dnums.add_collapsed_slice_dims(i);
      gather_dnums.add_start_index_map(i);
    }
    return Gather(input, gather_indices, gather_dnums, slice_sizes);
  });
}

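// Mimics torch.index_select: gathers, along dimension `dim` of `input`, the
// slices whose positions along `dim` are listed in `index`.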
XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64 dim) {
  XlaBuilder* builder = input.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
    TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index));
    // Each gathered slice spans the whole input except along `dim`, where a
    // single element is selected per index.
    std::vector<int64> slice_sizes = input_shape.dimensions();
    slice_sizes[dim] = 1;
    GatherDimensionNumbers gather_dnums;
    for (int64 i = 0; i < input_shape.rank(); ++i) {
      if (i != dim) {
        gather_dnums.add_offset_dims(i);
      }
    }
    gather_dnums.set_index_vector_dim(index_shape.rank());
    gather_dnums.add_collapsed_slice_dims(dim);
    gather_dnums.add_start_index_map(dim);
    return Gather(input, index, gather_dnums, slice_sizes);
  });
}

}  // namespace xla