/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
namespace {

// Create a diagonal / batch diagonal matrix with 'input' on the diagonal.
xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size,
                          absl::Span<const int64> other_dims,
                          xla::PrimitiveType element_type) {
  xla::XlaBuilder* builder = input.builder();
  // Create two matrices that have the following forms, and compare them:
  //
  // [[0, 0, 0, 0]            [[0, 1, 2, 3]
  //  [1, 1, 1, 1]             [0, 1, 2, 3]
  //  [2, 2, 2, 2]             [0, 1, 2, 3]
  //  [3, 3, 3, 3]]            [0, 1, 2, 3]]
  //
  // This produces a predicate matrix of the right size, with "true" on the
  // diagonal.
  xla::XlaOp iota = xla::Iota(builder, xla::S32, last_dim_size);
  xla::XlaOp iota_broadcast = xla::Broadcast(iota, {last_dim_size});
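  // The broadcast_dimensions argument {0} below maps iota's single dimension
  // to dimension 0 of iota_broadcast, so iota is compared as the row-constant
  // matrix on the left of the picture above; the comparison is "true" exactly
  // where the row index equals the column index.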
  xla::XlaOp mask = xla::Eq(iota_broadcast, iota, {0});

  // If this is a batched diagonal, broadcast the mask across the other
  // dimensions.
  if (!other_dims.empty()) {
    mask = xla::Broadcast(mask, other_dims);
  }

  // Broadcast the input, and then use the mask computed above to select the
  // diagonal:
  // e.g., in 2D:
  //         [[t, f, f]    [[1, 1, 1]    [[0, 0, 0]      [[1, 0, 0]
  // select(  [f, t, f]  ,  [4, 4, 4]  ,  [0, 0, 0]  ) =  [0, 4, 0]
  //          [f, f, t]]    [9, 9, 9]]    [0, 0, 0]]      [0, 0, 9]]
  //
  // Broadcasting the input is less-than-trivial, since we need to broadcast
  // into a "middle" dimension. We can do this with a reshape + implicit
  // broadcast.
  // TODO(b/30112114): Replace with in-dim broadcast when those are supported.
  std::vector<int64> broadcast_dims(other_dims.begin(), other_dims.end());
  broadcast_dims.push_back(1LL);
  broadcast_dims.push_back(last_dim_size);
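  // For example (shapes chosen only for illustration): with other_dims = {2}
  // and last_dim_size = 3, the input of shape {2, 3} is reshaped to {2, 1, 3}
  // here, and broadcast_dims becomes {2, 3, 3} below.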
  xla::XlaOp input_broadcast = xla::Reshape(input, broadcast_dims);

  broadcast_dims[broadcast_dims.size() - 2] = last_dim_size;
  auto broadcast_shape =
      xla::ShapeUtil::MakeShape(element_type, broadcast_dims);
  xla::XlaOp zeros = xla::Zeros(builder, broadcast_shape);

  input_broadcast = xla::Add(input_broadcast, zeros);
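  // Adding zeros of shape [..., last_dim_size, last_dim_size] implicitly
  // broadcasts the size-1 "middle" dimension, so every row of input_broadcast
  // now equals the input vector; the Select below keeps only the diagonal
  // entries and zeros out the rest.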
  return xla::Select(mask, input_broadcast, zeros);
}

class DiagOp : public XlaOpKernel {
 public:
  explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
                errors::InvalidArgument(
                    "Diag op must have at least one input"));
    const TensorShape input_shape = ctx->InputShape(0);

    auto dims = input_shape.dim_sizes();
    OP_REQUIRES(ctx, !dims.empty(),
                errors::InvalidArgument("Expected 1 <= dims, got shape ",
                                        input_shape.DebugString()));

    xla::XlaOp input = ctx->Input(0);

    // Picture:
    // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0]
    //                            [0, 2, 0, 0]
    //                            [0, 0, 3, 0]
    //                            [0, 0, 0, 4]]

    // Flattens the input to 1D.
    int64 size = input_shape.num_elements();
    input = xla::Reshape(input, {size});

    // Create an R2 with the R1 diagonal.
    xla::XlaOp diag =
        CreateDiagonal(input, size, /*other_dims=*/{}, ctx->input_xla_type(0));

    // Reshapes to the final shape.
    std::vector<int64> new_dims(dims.size() * 2);
    std::copy(dims.begin(), dims.end(), new_dims.begin());
    std::copy(dims.begin(), dims.end(), new_dims.begin() + dims.size());
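    // The output shape is the input shape repeated twice, e.g. an input of
    // shape [2, 3] produces an output of shape [2, 3, 2, 3].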
    diag = xla::Reshape(diag, new_dims);

    ctx->SetOutput(0, diag);
  }
};

REGISTER_XLA_OP(Name("Diag"), DiagOp);

class DiagPartOp : public XlaOpKernel {
 public:
  explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    auto dims = input_shape.dim_sizes();

    int num_dims = dims.size();
    const int out_dims = num_dims / 2;

    OP_REQUIRES(ctx, 2 <= num_dims,
                errors::InvalidArgument("Expected 2 <= dims, got shape ",
                                        input_shape.DebugString()));
    OP_REQUIRES(ctx, num_dims % 2 == 0,
                errors::InvalidArgument("The input tensor must have even rank; "
                                        "got shape ",
                                        input_shape.DebugString()));
    int64 new_size = 1;
    std::vector<int64> new_dims;
    for (int i = 0; i < out_dims; i++) {
      OP_REQUIRES(
          ctx, dims[i] == dims[i + out_dims],
          errors::InvalidArgument("Invalid shape ", input_shape.DebugString(),
                                  ": dimensions ", i, " and ", i + out_dims,
                                  " do not match."));
      new_size *= dims[i];
      new_dims.push_back(dims[i]);
    }

    xla::XlaOp input = ctx->Input(0);

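    // The input has shape [d_0, ..., d_{k-1}, d_0, ..., d_{k-1}]. Viewed as a
    // square matrix of side new_size = d_0 * ... * d_{k-1}, its main diagonal
    // holds exactly the elements input[i_0, ..., i_{k-1}, i_0, ..., i_{k-1}],
    // which is what DiagPart returns.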
    xla::XlaOp output = xla::Reshape(
        xla::GetMatrixDiagonal(xla::Reshape(input, {new_size, new_size})),
        new_dims);

    ctx->SetOutput(0, output);
  }
};

REGISTER_XLA_OP(Name("DiagPart"), DiagPartOp);

class MatrixDiagOp : public XlaOpKernel {
 public:
  explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
                errors::InvalidArgument(
                    "MatrixDiag op must have at least one input"));
    const TensorShape input_shape = ctx->InputShape(0);

    auto dims = input_shape.dim_sizes();
    OP_REQUIRES(ctx, !dims.empty(),
                errors::InvalidArgument("Expected 1 <= dims, got shape ",
                                        input_shape.DebugString()));

    int last_dim = dims.size() - 1;
    int64 last_dim_size = input_shape.dim_size(last_dim);
    absl::Span<const int64> other_dims(dims);
    other_dims.remove_suffix(1);
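    // other_dims now holds only the batch dimensions, so an input of shape
    // [..., k] yields a batch of diagonal matrices of shape [..., k, k].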

    xla::XlaOp input = ctx->Input(0);
    xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims,
                                     ctx->input_xla_type(0));
    ctx->SetOutput(0, diag);
  }
};

REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp);

class MatrixDiagPartOp : public XlaOpKernel {
 public:
  explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    auto dims = input_shape.dim_sizes();

    OP_REQUIRES(ctx, 2 <= dims.size(),
                errors::InvalidArgument("Expected 2 <= dims, got shape ",
                                        input_shape.DebugString()));

    xla::XlaOp input = ctx->Input(0);
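    // Unlike DiagPartOp above, no reshape is needed here:
    // xla::GetMatrixDiagonal extracts the diagonal of the two minor-most
    // dimensions and leaves any leading batch dimensions in place.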
    ctx->SetOutput(0, xla::GetMatrixDiagonal(input));
  }
};

REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);

}  // namespace
}  // namespace tensorflow