• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_
18 
#include <array>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
26 
27 namespace xla {
28 
29 // Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere
30 // else.
31 XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n);
32 
33 // Get the diagonals of the last two dimensions. Use k>0 for diagonals above the
34 // main diagonal, and k<0 for diagonals below the main diagonal.
35 //
36 // If 'x' has shape [..., M, N]
37 //  If k >= 0: then the output has shape [..., min(M, N - k)], containing the
38 //            diagonal elements (i.e., with indices [..., i, i + k]).
39 //  If k < 0: then the output has shape [..., min(M + k, N)], containing the
40 //            diagonal elements (i.e., with indices [..., i - k, i]).
41 XlaOp GetMatrixDiagonal(XlaOp x, int k = 0);
42 
43 // Returns a lower-triangular mask, i.e., true below the `diagonal`-th diagonal
44 // and false above that diagonal.
45 XlaOp TriangleMask(XlaOp x, int diagonal);
46 
47 // Get the upper or lower triangle part of the last two dimensions
48 XlaOp Triangle(XlaOp x, bool lower);
49 
50 // Get the upper triangle part of the last two dimensions
51 XlaOp UpperTriangle(XlaOp x);
52 
53 // Get the lower triangle part of the last two dimensions
54 XlaOp LowerTriangle(XlaOp x);
55 
56 // Multiplies slices of two tensors in batches.
57 
58 // Multiplies all slices of `Tensor` `x` and `y` (each slice can be
59 // viewed as an element of a batch), and arranges the individual results
60 // in a single output tensor of the same batch size.
61 //
62 // The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
63 // and `[..., r_y, c_y]`.
64 //
65 // The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:
66 //
67 //     r_o = c_x if transpose_x else r_x
68 //     c_o = r_y if transpose_y else c_y
69 //
70 // It is computed as:
71 //
72 //     output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
73 xla::XlaOp BatchDot(
74     xla::XlaOp x, xla::XlaOp y,
75     xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
76 
77 // Parse an einsum string into dimension numbers:
78 //   "ab,cb->ac"
79 // becomes:
80 //   {{0, 1},{2, 1},{0, 2}}
81 //
82 // NOTE: This function is meant for testing, there is no need to call it
83 // directly.
84 
85 StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString(
86     absl::string_view einsum_config);
87 
88 // Determine if each dimension label is in at least two inputs.
89 //
90 // NOTE: This function is meant for testing, there is no need to call it
91 // directly.
92 Status ValidateEinsumNumericDimensions(absl::Span<const int64> x_config,
93                                        absl::Span<const int64> y_config,
94                                        absl::Span<const int64> output_config);
95 
96 // Supports two operand einsum notation like "ab,cb->ac".
97 xla::XlaOp Einsum(
98     xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config,
99     xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
100 
101 // Same as above but supporting numeric labels on dimensins. So "ab,cb->ac"
102 // becomes:
103 //   x_config = {0, 1}
104 //   y_config = {2, 1}
105 //   output_config = {0, 2}
106 xla::XlaOp Einsum(
107     xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y,
108     absl::Span<const int64> y_config, absl::Span<const int64> output_config,
109     xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT);
110 
111 // Transposes a stack of matrices `x` by swapping the last two dimensions.
112 xla::XlaOp TransposeInMinorDims(xla::XlaOp x);
113 
114 // Transposes `x` in its minor dimensions if `transpose` is true, otherwise
115 // returns `x` unchanged.
116 xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose);
117 
118 }  // namespace xla
119 
120 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_
121