1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ 17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ 18 19 #include <array> 20 #include <vector> 21 22 #include "absl/strings/string_view.h" 23 #include "absl/types/span.h" 24 #include "tensorflow/compiler/xla/client/xla_builder.h" 25 #include "tensorflow/compiler/xla/statusor.h" 26 #include "tensorflow/compiler/xla/types.h" 27 #include "tensorflow/compiler/xla/xla_data.pb.h" 28 29 namespace xla { 30 31 // Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere 32 // else. 33 XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n); 34 35 // Returns a mask where the 'diagonal'-th diagonal is true and everything else 36 // is false. 37 XlaOp GetDiagonalMask(XlaOp x, int diagonal = 0); 38 39 // Get the diagonals of the last two dimensions. Use k>0 for diagonals above the 40 // main diagonal, and k<0 for diagonals below the main diagonal. 41 // 42 // If 'x' has shape [..., M, N] 43 // If k >= 0: then the output has shape [..., min(M, N - k)], containing the 44 // diagonal elements (i.e., with indices [..., i, i + k]). 45 // If k < 0: then the output has shape [..., min(M + k, N)], containing the 46 // diagonal elements (i.e., with indices [..., i - k, i]). 47 XlaOp GetMatrixDiagonal(XlaOp x, int k = 0); 48 XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k = 0); 49 50 // Places diag along the kth diagonal of target. 51 XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k = 0); 52 53 // Returns a lower-triangular mask, i.e., true below the `diagonal`-th diagonal 54 // and false above that diagonal. 55 XlaOp TriangleMask(XlaOp x, int diagonal); 56 57 // Get the upper or lower triangle part of the last two dimensions 58 XlaOp Triangle(XlaOp x, bool lower); 59 60 // Get the upper triangle part of the last two dimensions 61 XlaOp UpperTriangle(XlaOp x); 62 63 // Get the lower triangle part of the last two dimensions 64 XlaOp LowerTriangle(XlaOp x); 65 66 // Multiplies slices of two tensors in batches. 67 68 // Multiplies all slices of `Tensor` `x` and `y` (each slice can be 69 // viewed as an element of a batch), and arranges the individual results 70 // in a single output tensor of the same batch size. 71 // 72 // The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` 73 // and `[..., r_y, c_y]`. 74 // 75 // The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: 76 // 77 // r_o = c_x if transpose_x else r_x 78 // c_o = r_y if transpose_y else c_y 79 // 80 // It is computed as: 81 // 82 // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) 83 xla::XlaOp BatchDot( 84 xla::XlaOp x, xla::XlaOp y, 85 xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); 86 xla::XlaOp BatchDot( 87 xla::XlaOp x, bool transpose_x, xla::XlaOp y, bool transpose_y, 88 xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); 89 90 // Parse an einsum string into dimension numbers: 91 // "ab,cb->ac" 92 // becomes: 93 // {{0, 1},{2, 1},{0, 2}} 94 // 95 // Each occurrence of ellipsis ("...") occurring in the input is replaced with 96 // the same numeric dimensions. The number of such dimensions is inferred from 97 // x_rank and y_rank. For example: 98 // einsum_config: "...ab,...bcd->...acd" 99 // x_rank: 4 100 // y_rank: 5 101 // becomes: 102 // {{0, 1, 2, 3},{0, 1, 3, 4, 5},{0, 1, 2, 4, 5}} 103 // 104 // NOTE: This function is meant for testing, there is no need to call it 105 // directly. 106 107 StatusOr<std::array<std::vector<int64>, 3>> ParseEinsumString( 108 absl::string_view einsum_config, int64 x_rank, int64 y_rank); 109 110 // If an einsum config does not contain an -> one will be added and the output 111 // config will be the sorted characters with any ellipsis at the beginning. 112 // Returns an empty string if the einsum string already has an ->. 113 std::string NormalizeEinsumString(absl::string_view einsum_config); 114 115 // Supports two operand einsum notation like "ab,cb->ac". 116 xla::XlaOp Einsum( 117 xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config, 118 xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); 119 xla::XlaOp Einsum( 120 xla::XlaOp x, absl::string_view einsum_config, 121 xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); 122 123 124 // Same as above but supporting numeric labels on dimensions. So "ab,cb->ac" 125 // becomes: 126 // x_config = {0, 1} 127 // y_config = {2, 1} 128 // output_config = {0, 2} 129 xla::XlaOp Einsum( 130 xla::XlaOp x, absl::Span<const int64> x_config, xla::XlaOp y, 131 absl::Span<const int64> y_config, absl::Span<const int64> output_config, 132 xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); 133 134 // Transposes a stack of matrices `x` by swapping the last two dimensions. 135 xla::XlaOp TransposeInMinorDims(xla::XlaOp x); 136 137 // Transposes `x` in its minor dimensions if `transpose` is true, otherwise 138 // returns `x` unchanged. 139 xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose); 140 141 } // namespace xla 142 143 #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATRIX_H_ 144