/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/matrix.h"

#include <algorithm>
#include <array>
#include <numeric>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m,
                     int64 n) {
  auto a = Iota(builder, U32, m);
  auto b = Iota(builder, U32, n);
  auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0});
  return ConvertElementType(indicator, type);
}
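
// Usage sketch (illustrative, not part of the library): IdentityMatrix emits
// ones on the main diagonal and zeros elsewhere, so a rectangular request is
// fine.
//
//   XlaBuilder b("identity_example");
//   XlaOp eye = IdentityMatrix(&b, F32, /*m=*/2, /*n=*/3);
//   // eye == [[1, 0, 0],
//   //         [0, 1, 0]]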

XlaOp GetMatrixDiagonal(XlaOp x, int k) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = shape.rank();
    TF_RET_CHECK(n_dims >= 2);
    const int64 m = shape.dimensions(n_dims - 2);
    const int64 n = shape.dimensions(n_dims - 1);

    auto offset = ConstantR0WithType(builder, S32, k);

    absl::Span<const int64> major_dims =
        AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
    auto a = Iota(builder, S32, n);
    auto b = Iota(builder, S32, m) + offset;
    auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
    auto mask = Broadcast(indicator, major_dims);

    // TPUs don't support S64 add reduction at the moment. But fortunately
    // OR-reductions work just as well for integers.
    XlaComputation reducer =
        primitive_util::IsIntegralType(shape.element_type())
            ? CreateScalarOrComputation(shape.element_type(), builder)
            : CreateScalarAddComputation(shape.element_type(), builder);
    // When k == 0, we can save one slice op.
    if (k == 0) {
      return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0),
                    reducer, {m >= n ? n_dims - 2 : n_dims - 1});
    } else if (k > 0) {
      auto result = Reduce(Select(mask, x, Zeros(builder, shape)),
                           ScalarLike(x, 0), reducer, {n_dims - 2});
      return SliceInMinorDims(result, {std::min<int64>(k, n)},
                              {std::min(m + k, n)});
    } else {
      auto result = Reduce(Select(mask, x, Zeros(builder, shape)),
                           ScalarLike(x, 0), reducer, {n_dims - 1});
      return SliceInMinorDims(result, {std::min<int64>(-k, m)},
                              {std::min(m, n - k)});
    }
  });
}
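
// Illustrative semantics: for a single 3x3 matrix x, GetMatrixDiagonal(x, 0)
// returns [x00, x11, x22], k = 1 returns the superdiagonal [x01, x12], and
// k = -1 returns the subdiagonal [x10, x21]. Any leading batch dimensions are
// carried through unchanged.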

XlaOp TriangleMask(XlaOp x, int diagonal) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = shape.rank();
    TF_RET_CHECK(n_dims >= 2);
    const int64 m = shape.dimensions(n_dims - 2);
    const int64 n = shape.dimensions(n_dims - 1);
    absl::Span<const int64> major_dims =
        AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
    auto a = Iota(builder, S32, n);
    auto b = Iota(builder, S32, m) + ConstantR0<int32>(builder, diagonal);
    auto indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
    return Broadcast(indicator, major_dims);
  });
}
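
// Illustrative semantics: TriangleMask(x, diagonal) is true at (i, j) iff
// i + diagonal >= j. For a 3x3 input and diagonal = 0 the mask is
//
//   [[1, 0, 0],
//    [1, 1, 0],
//    [1, 1, 1]]
//
// i.e. the lower triangle including the main diagonal.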

XlaOp Triangle(XlaOp x, bool lower) {
  return lower ? Select(TriangleMask(x, 0), x, ZerosLike(x))
               : Select(TriangleMask(x, -1), ZerosLike(x), x);
}

XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); }

XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }
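
// Illustrative example: for x == [[1, 2], [3, 4]],
//   LowerTriangle(x) == [[1, 0], [3, 4]]   (keeps on/below the diagonal)
//   UpperTriangle(x) == [[1, 2], [0, 4]]   (keeps on/above the diagonal)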

Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
                                       absl::Span<const int64> y_config,
                                       absl::Span<const int64> output_config) {
  for (auto dim : output_config) {
    if (absl::c_linear_search(x_config, dim) ||
        absl::c_linear_search(y_config, dim)) {
      if (absl::c_count(output_config, dim) > 1) {
        return InvalidArgument("Einsum has repeated output dimension.");
      }
      continue;
    }
    return InvalidArgument(
        "Einsum has output dimension without corresponding input dimension.");
  }
  for (auto dim : x_config) {
    if (absl::c_linear_search(y_config, dim) ||
        absl::c_linear_search(output_config, dim)) {
      if (absl::c_count(x_config, dim) > 1) {
        return InvalidArgument("Einsum has repeated lhs dimension.");
      }
      continue;
    }
    return InvalidArgument(
        "Einsum has lhs dimension without corresponding rhs or output "
        "dimension.");
  }
  for (auto dim : y_config) {
    if (absl::c_linear_search(x_config, dim) ||
        absl::c_linear_search(output_config, dim)) {
      if (absl::c_count(y_config, dim) > 1) {
        return InvalidArgument("Einsum has repeated rhs dimension.");
      }
      continue;
    }
    return InvalidArgument(
        "Einsum has rhs dimension without corresponding lhs or output "
        "dimension.");
  }
  return Status::OK();
}
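
// Illustrative example: the usual matmul config "mk,kn->mn" encoded as
// integers, x_config = {'m', 'k'}, y_config = {'k', 'n'},
// output_config = {'m', 'n'}, validates: every output label appears in an
// input, and 'k' is absent from the output, which is what marks it as a
// contracting dimension. A config like "mk,kn->mm" is rejected for its
// repeated output dimension.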

xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
                  absl::Span<const int64> y_config,
                  absl::Span<const int64> output_config,
                  xla::PrecisionConfig::Precision precision) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(
        ValidateEinsumNumericDimensions(x_config, y_config, output_config));
    const int64 x_rank = x_config.size();
    const int64 y_rank = y_config.size();
    const int64 output_rank = output_config.size();
    absl::flat_hash_set<int64> x_map;
    absl::flat_hash_set<int64> y_map;
    absl::flat_hash_set<int64> output_map;

    auto find = [&](const absl::flat_hash_set<int64>& map, int64 d) {
      return map.count(d) != 0;
    };

    auto insert = [&](absl::flat_hash_set<int64>& map, int64 d) {
      CHECK(!find(map, d));
      map.insert(d);
    };

    for (auto d : x_config) {
      insert(x_map, d);
    }

    for (auto d : y_config) {
      insert(y_map, d);
    }

    for (auto d : output_config) {
      insert(output_map, d);
    }

    DotDimensionNumbers dnums;
    std::vector<int64> lhs_outer_dims;
    auto is_batch_dim = [&](int64 d) {
      return find(x_map, d) && find(y_map, d) && find(output_map, d);
    };
    auto is_contracting = [&](int64 d) {
      return find(x_map, d) && find(y_map, d);
    };
    auto rhs_dimension_number = [&](int64 d) {
      return absl::c_find(y_config, d) - y_config.begin();
    };
    for (int64 i = 0; i < x_rank; ++i) {
      auto dim_name = x_config[i];
      if (is_batch_dim(dim_name)) {
        dnums.add_lhs_batch_dimensions(i);
        dnums.add_rhs_batch_dimensions(rhs_dimension_number(dim_name));
      } else if (is_contracting(dim_name)) {
        dnums.add_lhs_contracting_dimensions(i);
        dnums.add_rhs_contracting_dimensions(rhs_dimension_number(dim_name));
      } else {
        lhs_outer_dims.push_back(i);
      }
    }

    std::vector<int64> rhs_outer_dims;
    for (int64 i = 0; i < y_rank; ++i) {
      auto dim_name = y_config[i];
      if (!is_batch_dim(dim_name) && !is_contracting(dim_name)) {
        rhs_outer_dims.push_back(i);
      }
    }

    auto output_dimension_number = [&](int64 d) {
      return absl::c_find(output_config, d) - output_config.begin();
    };

    std::vector<int64> output_dims;
    output_dims.reserve(output_rank);
    for (auto d : dnums.lhs_batch_dimensions()) {
      output_dims.push_back(output_dimension_number(x_config[d]));
    }
    for (auto d : lhs_outer_dims) {
      output_dims.push_back(output_dimension_number(x_config[d]));
    }
    for (auto d : rhs_outer_dims) {
      output_dims.push_back(output_dimension_number(y_config[d]));
    }

    std::vector<int64> transpose_dims(output_rank);
    for (int64 i = 0; i < output_rank; ++i) {
      transpose_dims[output_dims[i]] = i;
    }

    PrecisionConfig precision_proto;
    precision_proto.add_operand_precision(precision);
    precision_proto.add_operand_precision(precision);
    return Transpose(DotGeneral(x, y, dnums, &precision_proto), transpose_dims);
  });
}
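
// Usage sketch (illustrative): a batched matrix multiply "bij,bjk->bik"
// expressed with numeric labels. 'b' appears in both inputs and the output,
// so it becomes a batch dimension of the DotGeneral; 'j' appears in both
// inputs but not the output, so it is contracted.
//
//   XlaOp out = Einsum(x, /*x_config=*/{'b', 'i', 'j'},
//                      y, /*y_config=*/{'b', 'j', 'k'},
//                      /*output_config=*/{'b', 'i', 'k'},
//                      PrecisionConfig::DEFAULT);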

XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
    TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));

    // The batch dimensions must be equal and the matrix dimensions must be
    // valid.
    std::vector<int64> batch_dimension_numbers;
    const int ndims = x_shape.rank();
    batch_dimension_numbers.reserve(ndims - 2);
    for (int i = 0; i < ndims - 2; ++i) {
      batch_dimension_numbers.push_back(i);
    }
    std::vector<int64> x_config = batch_dimension_numbers;
    x_config.push_back(ndims - 2);
    x_config.push_back(ndims);
    std::vector<int64> y_config = batch_dimension_numbers;
    y_config.push_back(ndims);
    y_config.push_back(ndims - 1);
    std::vector<int64> output_config = batch_dimension_numbers;
    output_config.push_back(ndims - 2);
    output_config.push_back(ndims - 1);
    return Einsum(x, x_config, y, y_config, output_config, precision);
  });
}
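
// Illustrative example: for x of shape [B, M, K] and y of shape [B, K, N],
// BatchDot(x, y) contracts the K dimension of each batch element and returns
// shape [B, M, N]; the config triple built above is exactly the numeric form
// of "bmk,bkn->bmn" (label ndims is the contracted dimension, which never
// collides with the 0..ndims-1 labels used for batch and output dimensions).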

StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString(
    absl::string_view einsum_config) {
  std::array<std::vector<int64>, 3> einsum_config_numeric;
  std::vector<absl::string_view> main_split =
      absl::StrSplit(einsum_config, ',');

  if (main_split.size() != 2) {
    return InvalidArgument("Expected one \",\" in einsum_config.");
  }

  auto maybe_invalid_character = [](char d) {
    if (absl::ascii_isalpha(d)) {
      return Status::OK();
    }
    if (d == '.') {
      return InvalidArgument("Unsupported \"...\" or \".\" in einsum config.");
    }
    return InvalidArgument("Unexpected character in einsum config.");
  };

  auto& x_config = einsum_config_numeric[0];
  x_config.reserve(main_split[0].size());
  for (auto d : main_split[0]) {
    TF_RETURN_IF_ERROR(maybe_invalid_character(d));
    x_config.push_back(static_cast<int64>(d));
  }
  std::vector<absl::string_view> y_output_split =
      absl::StrSplit(main_split[1], "->");
  if (y_output_split.size() != 2) {
    return InvalidArgument("Expected one \"->\" in einsum_config.");
  }
  auto& y_config = einsum_config_numeric[1];
  y_config.reserve(y_output_split[0].size());
  for (auto d : y_output_split[0]) {
    TF_RETURN_IF_ERROR(maybe_invalid_character(d));
    y_config.push_back(static_cast<int64>(d));
  }
  auto& output_config = einsum_config_numeric[2];
  output_config.reserve(y_output_split[1].size());
  for (auto d : y_output_split[1]) {
    TF_RETURN_IF_ERROR(maybe_invalid_character(d));
    output_config.push_back(static_cast<int64>(d));
  }
  return einsum_config_numeric;
}
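
// Illustrative example: ParseEinsumString("ab,bc->ac") succeeds and returns
// the three label vectors {{'a', 'b'}, {'b', 'c'}, {'a', 'c'}} as int64
// values, while "a.b,bc->ac" (ellipsis syntax is unsupported) and "ab,bc"
// (missing "->") both yield InvalidArgument.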

XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config,
             PrecisionConfig::Precision precision) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto einsum_config_numeric,
                        ParseEinsumString(einsum_config));
    return Einsum(x, einsum_config_numeric[0], y, einsum_config_numeric[1],
                  einsum_config_numeric[2], precision);
  });
}
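
// Usage sketch (illustrative): the string overload is a thin wrapper, so
//
//   XlaOp out = Einsum(x, y, "bij,bjk->bik", PrecisionConfig::DEFAULT);
//
// parses the config and dispatches to the numeric overload above; for rank-3
// inputs it matches BatchDot(x, y).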

XlaOp TransposeInMinorDims(XlaOp x) {
  XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
    const int64 n_dims = shape.rank();
    TF_RET_CHECK(n_dims >= 2);
    std::vector<int64> permutation(n_dims);
    std::iota(permutation.begin(), permutation.end(), 0);
    std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
    return Transpose(x, permutation);
  });
}
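
// Illustrative example: for x of shape [B, M, N], TransposeInMinorDims(x)
// returns shape [B, N, M]. The permutation built above is {0, ..., n_dims-1}
// with the last two entries exchanged, so leading (batch) dimensions are
// untouched.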

XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) {
  return transpose ? TransposeInMinorDims(x) : x;
}

}  // namespace xla