• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/matrix.h"
17 
18 #include <array>
19 #include <numeric>
20 #include <vector>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/ascii.h"
25 #include "absl/strings/str_split.h"
26 #include "absl/strings/string_view.h"
27 #include "absl/types/span.h"
28 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
29 #include "tensorflow/compiler/xla/client/lib/constants.h"
30 #include "tensorflow/compiler/xla/client/lib/slicing.h"
31 #include "tensorflow/compiler/xla/client/xla_builder.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/statusor.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 
39 namespace xla {
40 
IdentityMatrix(XlaBuilder * builder,PrimitiveType type,int64 m,int64 n)41 XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m,
42                      int64 n) {
43   auto a = Iota(builder, U32, m);
44   auto b = Iota(builder, U32, n);
45   auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0});
46   return ConvertElementType(indicator, type);
47 }
48 
GetMatrixDiagonal(XlaOp x,int k)49 XlaOp GetMatrixDiagonal(XlaOp x, int k) {
50   XlaBuilder* builder = x.builder();
51   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
52     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
53     const int64 n_dims = shape.rank();
54     TF_RET_CHECK(n_dims >= 2);
55     const int64 m = shape.dimensions(n_dims - 2);
56     const int64 n = shape.dimensions(n_dims - 1);
57 
58     auto offset = ConstantR0WithType(builder, S32, k);
59 
60     absl::Span<const int64> major_dims =
61         AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
62     auto a = Iota(builder, S32, n);
63     auto b = Iota(builder, S32, m) + offset;
64     auto indicator = Eq(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
65     auto mask = Broadcast(indicator, major_dims);
66 
67     // TPUs don't support S64 add reduction at the moment. But fortunately
68     // OR-reductions work just as well for integers.
69     XlaComputation reducer =
70         primitive_util::IsIntegralType(shape.element_type())
71             ? CreateScalarOrComputation(shape.element_type(), builder)
72             : CreateScalarAddComputation(shape.element_type(), builder);
73     // k == 0, we can save one slice op.
74     if (k == 0) {
75       return Reduce(Select(mask, x, Zeros(builder, shape)), ScalarLike(x, 0),
76                     reducer, {m >= n ? n_dims - 2 : n_dims - 1});
77     } else if (k > 0) {
78       auto result = Reduce(Select(mask, x, Zeros(builder, shape)),
79                            ScalarLike(x, 0), reducer, {n_dims - 2});
80       return SliceInMinorDims(result, {std::min<int64>(k, n)},
81                               {std::min(m + k, n)});
82     } else {
83       auto result = Reduce(Select(mask, x, Zeros(builder, shape)),
84                            ScalarLike(x, 0), reducer, {n_dims - 1});
85       return SliceInMinorDims(result, {std::min<int64>(-k, m)},
86                               {std::min(m, n - k)});
87     }
88   });
89 }
90 
TriangleMask(XlaOp x,int diagonal)91 XlaOp TriangleMask(XlaOp x, int diagonal) {
92   XlaBuilder* builder = x.builder();
93   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
94     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
95     const int64 n_dims = shape.rank();
96     TF_RET_CHECK(n_dims >= 2);
97     const int64 m = shape.dimensions(n_dims - 2);
98     const int64 n = shape.dimensions(n_dims - 1);
99     absl::Span<const int64> major_dims =
100         AsInt64Slice(shape.dimensions()).subspan(/*pos=*/0, /*len=*/n_dims - 2);
101     auto a = Iota(builder, S32, n);
102     auto b = Iota(builder, S32, m) + ConstantR0<int32>(builder, diagonal);
103     XlaOp indicator;
104     indicator = Ge(b, Broadcast(a, {m}), /*broadcast_dimensions=*/{0});
105     return Broadcast(indicator, major_dims);
106   });
107 }
108 
Triangle(XlaOp x,bool lower)109 XlaOp Triangle(XlaOp x, bool lower) {
110   return lower ? Select(TriangleMask(x, 0), x, ZerosLike(x))
111                : Select(TriangleMask(x, -1), ZerosLike(x), x);
112 }
113 
UpperTriangle(XlaOp x)114 XlaOp UpperTriangle(XlaOp x) { return Triangle(x, false); }
115 
LowerTriangle(XlaOp x)116 XlaOp LowerTriangle(XlaOp x) { return Triangle(x, true); }
117 
ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,absl::Span<const int64> y_config,absl::Span<const int64> output_config)118 Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
119                                        absl::Span<const int64> y_config,
120                                        absl::Span<const int64> output_config) {
121   for (auto dim : output_config) {
122     if (absl::c_linear_search(x_config, dim) ||
123         absl::c_linear_search(y_config, dim)) {
124       if (absl::c_count(output_config, dim) > 1) {
125         return InvalidArgument("Einsum has repeated output dimension.");
126       }
127       continue;
128     }
129     return InvalidArgument(
130         "Einsum has output dimension without corresponding input dimension.");
131   }
132   for (auto dim : x_config) {
133     if (absl::c_linear_search(y_config, dim) ||
134         absl::c_linear_search(output_config, dim)) {
135       if (absl::c_count(x_config, dim) > 1) {
136         return InvalidArgument("Einsum has repeated lhs dimension.");
137       }
138       continue;
139     }
140     return InvalidArgument(
141         "Einsum has lhs dimension without corresponding rhs or output "
142         "dimension.");
143   }
144   for (auto dim : y_config) {
145     if (absl::c_linear_search(x_config, dim) ||
146         absl::c_linear_search(output_config, dim)) {
147       if (absl::c_count(y_config, dim) > 1) {
148         return InvalidArgument("Einsum has repeated rhs dimension.");
149       }
150       continue;
151     }
152     return InvalidArgument(
153         "Einsum has rhs dimension without corresponding lhs or output "
154         "dimension.");
155   }
156   return Status::OK();
157 }
158 
Einsum(xla::XlaOp x,absl::Span<const int64> x_config,xla::XlaOp y,absl::Span<const int64> y_config,absl::Span<const int64> output_config,xla::PrecisionConfig::Precision precision)159 xla::XlaOp Einsum(xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
160                   absl::Span<const int64> y_config,
161                   absl::Span<const int64> output_config,
162                   xla::PrecisionConfig::Precision precision) {
163   XlaBuilder* builder = x.builder();
164   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
165     TF_RETURN_IF_ERROR(
166         ValidateEinsumNumericDimensions(x_config, y_config, output_config));
167     const int64 x_rank = x_config.size();
168     const int64 y_rank = y_config.size();
169     const int64 output_rank = output_config.size();
170     absl::flat_hash_set<int64> x_map;
171     absl::flat_hash_set<int64> y_map;
172     absl::flat_hash_set<int64> output_map;
173 
174     auto find = [&](const absl::flat_hash_set<int64>& map, int64 d) {
175       return map.count(d) != 0;
176     };
177 
178     auto insert = [&](absl::flat_hash_set<int64>& map, char d) {
179       CHECK(!find(map, d));
180       map.insert(d);
181     };
182 
183     for (auto d : x_config) {
184       insert(x_map, d);
185     }
186 
187     for (auto d : y_config) {
188       insert(y_map, d);
189     }
190 
191     for (auto d : output_config) {
192       insert(output_map, d);
193     }
194 
195     DotDimensionNumbers dnums;
196     std::vector<int64> lhs_outer_dims;
197     auto is_batch_dim = [&](int64 d) {
198       return find(x_map, d) && find(y_map, d) && find(output_map, d);
199     };
200     auto is_contracting = [&](int64 d) {
201       return find(x_map, d) && find(y_map, d);
202     };
203     auto rhs_dimension_number = [&](int64 d) {
204       return absl::c_find(y_config, d) - y_config.begin();
205     };
206     for (int64 i = 0; i < x_rank; ++i) {
207       auto dim_name = x_config[i];
208       if (is_batch_dim(dim_name)) {
209         dnums.add_lhs_batch_dimensions(i);
210         dnums.add_rhs_batch_dimensions(rhs_dimension_number(dim_name));
211       } else if (is_contracting(dim_name)) {
212         dnums.add_lhs_contracting_dimensions(i);
213         dnums.add_rhs_contracting_dimensions(rhs_dimension_number(dim_name));
214       } else {
215         lhs_outer_dims.push_back(i);
216       }
217     }
218 
219     std::vector<int64> rhs_outer_dims;
220     for (int64 i = 0; i < y_rank; ++i) {
221       auto dim_name = y_config[i];
222       if (!is_batch_dim(dim_name) && !is_contracting(dim_name)) {
223         rhs_outer_dims.push_back(i);
224       }
225     }
226 
227     auto output_dimension_number = [&](char d) {
228       return absl::c_find(output_config, d) - output_config.begin();
229     };
230 
231     std::vector<int64> output_dims;
232     output_dims.reserve(output_rank);
233     for (auto d : dnums.lhs_batch_dimensions()) {
234       output_dims.push_back(output_dimension_number(x_config[d]));
235     }
236     for (auto d : lhs_outer_dims) {
237       output_dims.push_back(output_dimension_number(x_config[d]));
238     }
239     for (auto d : rhs_outer_dims) {
240       output_dims.push_back(output_dimension_number(y_config[d]));
241     }
242 
243     std::vector<int64> transpose_dims(output_rank);
244     for (int64 i = 0; i < output_rank; ++i) {
245       transpose_dims[output_dims[i]] = i;
246     }
247 
248     PrecisionConfig precision_proto;
249     precision_proto.add_operand_precision(precision);
250     precision_proto.add_operand_precision(precision);
251     return Transpose(DotGeneral(x, y, dnums, &precision_proto), transpose_dims);
252   });
253 }
254 
BatchDot(XlaOp x,XlaOp y,PrecisionConfig::Precision precision)255 XlaOp BatchDot(XlaOp x, XlaOp y, PrecisionConfig::Precision precision) {
256   XlaBuilder* builder = x.builder();
257   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
258     TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
259     TF_ASSIGN_OR_RETURN(Shape y_shape, builder->GetShape(y));
260 
261     // The batch dimensions must be equal and the matrix dimensions must be
262     // valid.
263     std::vector<int64> batch_dimension_numbers;
264     const int ndims = x_shape.rank();
265     batch_dimension_numbers.reserve(ndims - 2);
266     for (int i = 0; i < ndims - 2; ++i) {
267       batch_dimension_numbers.push_back(i);
268     }
269     std::vector<int64> x_config = batch_dimension_numbers;
270     x_config.push_back(ndims - 2);
271     x_config.push_back(ndims);
272     std::vector<int64> y_config = batch_dimension_numbers;
273     y_config.push_back(ndims);
274     y_config.push_back(ndims - 1);
275     std::vector<int64> output_config = batch_dimension_numbers;
276     output_config.push_back(ndims - 2);
277     output_config.push_back(ndims - 1);
278     return Einsum(x, x_config, y, y_config, output_config, precision);
279   });
280 }
281 
ParseEinsumString(absl::string_view einsum_config)282 StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString(
283     absl::string_view einsum_config) {
284   std::array<std::vector<int64>, 3> einsum_config_numeric;
285   std::vector<absl::string_view> main_split =
286       absl::StrSplit(einsum_config, ',');
287 
288   if (main_split.size() != 2) {
289     return InvalidArgument("Expected one \",\" in einsum_config.");
290   }
291 
292   auto maybe_invalid_character = [](char d) {
293     if (absl::ascii_isalpha(d)) {
294       return Status::OK();
295     }
296     if (d == '.') {
297       return InvalidArgument("Unsupported \"...\" or \".\" in einsum config.");
298     }
299     return InvalidArgument("Unexpected character in einsum config.");
300   };
301 
302   auto& x_config = einsum_config_numeric[0];
303   x_config.reserve(main_split[0].size());
304   for (auto d : main_split[0]) {
305     TF_RETURN_IF_ERROR(maybe_invalid_character(d));
306     x_config.push_back(static_cast<int64>(d));
307   }
308   std::vector<absl::string_view> y_output_split =
309       absl::StrSplit(main_split[1], "->");
310   if (y_output_split.size() != 2) {
311     return InvalidArgument("Expected one \"->\" in einsum_config.");
312   }
313   auto& y_config = einsum_config_numeric[1];
314   y_config.reserve(y_output_split[0].size());
315   for (auto d : y_output_split[0]) {
316     TF_RETURN_IF_ERROR(maybe_invalid_character(d));
317     y_config.push_back(static_cast<int64>(d));
318   }
319   auto& output_config = einsum_config_numeric[2];
320   output_config.reserve(y_output_split[1].size());
321   for (auto d : y_output_split[1]) {
322     TF_RETURN_IF_ERROR(maybe_invalid_character(d));
323     output_config.push_back(static_cast<int64>(d));
324   }
325   return einsum_config_numeric;
326 }
327 
Einsum(XlaOp x,XlaOp y,absl::string_view einsum_config,PrecisionConfig::Precision precision)328 XlaOp Einsum(XlaOp x, XlaOp y, absl::string_view einsum_config,
329              PrecisionConfig::Precision precision) {
330   XlaBuilder* builder = x.builder();
331   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
332     TF_ASSIGN_OR_RETURN(auto einsum_config_numeric,
333                         ParseEinsumString(einsum_config));
334     return Einsum(x, einsum_config_numeric[0], y, einsum_config_numeric[1],
335                   einsum_config_numeric[2], precision);
336   });
337 }
338 
TransposeInMinorDims(XlaOp x)339 XlaOp TransposeInMinorDims(XlaOp x) {
340   XlaBuilder* builder = x.builder();
341   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
342     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x));
343     const int64 n_dims = shape.rank();
344     TF_RET_CHECK(n_dims >= 2);
345     std::vector<int64> permutation(n_dims);
346     std::iota(permutation.begin(), permutation.end(), 0);
347     std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
348     return Transpose(x, permutation);
349   });
350 }
351 
MaybeTransposeInMinorDims(XlaOp x,bool transpose)352 XlaOp MaybeTransposeInMinorDims(XlaOp x, bool transpose) {
353   return transpose ? TransposeInMinorDims(x) : x;
354 }
355 
356 }  // namespace xla
357