/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
namespace {

// Create a diagonal / batch diagonal matrix with 'input' on the diagonal.
xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size,
                          absl::Span<const int64> other_dims,
                          xla::PrimitiveType element_type) {
  xla::XlaBuilder* builder = input.builder();
  // Create two matrices that have the following forms, and compare them:
  //
  // [[0, 0, 0, 0]            [[0, 1, 2, 3]
  //  [1, 1, 1, 1]             [0, 1, 2, 3]
  //  [2, 2, 2, 2]             [0, 1, 2, 3]
  //  [3, 3, 3, 3]]            [0, 1, 2, 3]]
  //
  // This produces a predicate matrix of the right size, with "true" on the
  // diagonal.
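  // Concretely, iota has shape [last_dim_size] and iota_broadcast has shape
  // [last_dim_size, last_dim_size] with each row equal to iota (the
  // right-hand matrix above). The {0} passed to Eq maps iota's single
  // dimension onto dimension 0 of iota_broadcast (the left-hand matrix
  // above), so mask[i][j] = (j == i).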
  xla::XlaOp iota = xla::Iota(builder, xla::S32, last_dim_size);
  xla::XlaOp iota_broadcast = xla::Broadcast(iota, {last_dim_size});
  xla::XlaOp mask = xla::Eq(iota_broadcast, iota, {0});

  // If this is a batched diagonal, broadcast the mask across the other
  // dimensions.
  if (!other_dims.empty()) {
    mask = xla::Broadcast(mask, other_dims);
  }

  // Broadcast the input, and then use the mask computed above to select the
  // diagonal:
  // e.g., in 2D:
  //         [[t, f, f]    [[1, 1, 1]    [[0, 0, 0]      [[1, 0, 0]
  // select(  [f, t, f]  ,  [4, 4, 4]  ,  [0, 0, 0]  ) =  [0, 4, 0]
  //          [f, f, t]]    [9, 9, 9]]    [0, 0, 0]]      [0, 0, 9]]
  //
  // Broadcasting the input is less-than-trivial, since we need to broadcast
  // into a "middle" dimension. We can do this with a reshape + implicit
  // broadcast.
  // TODO(b/30112114): Replace with in-dim broadcast when those are supported.
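  // For example, with other_dims = {2} and last_dim_size = 3, the input of
  // shape [2, 3] is reshaped to [2, 1, 3]; adding zeros of shape [2, 3, 3]
  // below implicitly broadcasts the size-1 dimension, replicating each input
  // row across a full [3, 3] slice.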
  std::vector<int64> broadcast_dims(other_dims.begin(), other_dims.end());
  broadcast_dims.push_back(1LL);
  broadcast_dims.push_back(last_dim_size);
  xla::XlaOp input_broadcast = xla::Reshape(input, broadcast_dims);

  broadcast_dims[broadcast_dims.size() - 2] = last_dim_size;
  auto broadcast_shape =
      xla::ShapeUtil::MakeShape(element_type, broadcast_dims);
  xla::XlaOp zeros = xla::Zeros(builder, broadcast_shape);

  input_broadcast = xla::Add(input_broadcast, zeros);
  return xla::Select(mask, input_broadcast, zeros);
}

class DiagOp : public XlaOpKernel {
 public:
  explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
                errors::InvalidArgument(
                    "Diag op must have at least one input"));
    const TensorShape input_shape = ctx->InputShape(0);

    auto dims = input_shape.dim_sizes();
    OP_REQUIRES(ctx, !dims.empty(),
                errors::InvalidArgument("Expected 1 <= dims, got shape ",
                                        input_shape.DebugString()));

    xla::XlaOp input = ctx->Input(0);

    // Picture:
    // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0]
    //                            [0, 2, 0, 0]
    //                            [0, 0, 3, 0]
    //                            [0, 0, 0, 4]]

    // Flattens the input to 1D.
    int64 size = input_shape.num_elements();
    input = xla::Reshape(input, {size});

    // Create an R2 with the R1 diagonal.
    xla::XlaOp diag =
        CreateDiagonal(input, size, /*other_dims=*/{}, ctx->input_xla_type(0));

    // Reshapes to the final shape.
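    // The output shape is the input shape repeated, e.g. an input of shape
    // [2, 3] (6 elements) becomes a [6, 6] diagonal matrix reshaped to
    // [2, 3, 2, 3].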
    std::vector<int64> new_dims(dims.size() * 2);
    std::copy(dims.begin(), dims.end(), new_dims.begin());
    std::copy(dims.begin(), dims.end(), new_dims.begin() + dims.size());
    diag = xla::Reshape(diag, new_dims);

    ctx->SetOutput(0, diag);
  }
};

REGISTER_XLA_OP(Name("Diag"), DiagOp);

class DiagPartOp : public XlaOpKernel {
 public:
  explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    auto dims = input_shape.dim_sizes();

    int num_dims = dims.size();
    const int out_dims = num_dims / 2;

    OP_REQUIRES(ctx, 2 <= num_dims,
                errors::InvalidArgument("Expected 2 <= dims, got shape ",
                                        input_shape.DebugString()));
    OP_REQUIRES(ctx, num_dims % 2 == 0,
                errors::InvalidArgument("The input tensor must have even rank; "
                                        "got shape ",
                                        input_shape.DebugString()));
    int64 new_size = 1;
    std::vector<int64> new_dims;
    for (int i = 0; i < out_dims; i++) {
      OP_REQUIRES(
          ctx, dims[i] == dims[i + out_dims],
          errors::InvalidArgument("Invalid shape ", input_shape.DebugString(),
                                  ": dimensions ", i, " and ", i + out_dims,
                                  " do not match."));
      new_size *= dims[i];
      new_dims.push_back(dims[i]);
    }

    xla::XlaOp input = ctx->Input(0);

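    // Collapse the input to a square matrix, take its main diagonal, and
    // restore the leading dimensions: e.g. a [2, 3, 2, 3] input is reshaped
    // to [6, 6], whose 6-element diagonal is then reshaped to [2, 3].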
    xla::XlaOp output = xla::Reshape(
        xla::GetMatrixDiagonal(xla::Reshape(input, {new_size, new_size})),
        new_dims);

    ctx->SetOutput(0, output);
  }
};

REGISTER_XLA_OP(Name("DiagPart"), DiagPartOp);

class MatrixDiagOp : public XlaOpKernel {
 public:
  explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
                errors::InvalidArgument(
                    "MatrixDiag op must have at least one input"));
    const TensorShape input_shape = ctx->InputShape(0);

    auto dims = input_shape.dim_sizes();
    OP_REQUIRES(ctx, !dims.empty(),
                errors::InvalidArgument("Expected 1 <= dims, got shape ",
                                        input_shape.DebugString()));

    int last_dim = dims.size() - 1;
    int64 last_dim_size = input_shape.dim_size(last_dim);
    absl::Span<const int64> other_dims(dims);
    other_dims.remove_suffix(1);

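    // e.g. an input of shape [2, 4, 3] has other_dims = {2, 4} and
    // last_dim_size = 3, yielding a batch of diagonal matrices of shape
    // [2, 4, 3, 3].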
    xla::XlaOp input = ctx->Input(0);
    xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims,
                                     ctx->input_xla_type(0));
    ctx->SetOutput(0, diag);
  }
};

REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp);

class MatrixDiagPartOp : public XlaOpKernel {
 public:
  explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    auto dims = input_shape.dim_sizes();

    OP_REQUIRES(ctx, 2 <= dims.size(),
                errors::InvalidArgument("Expected 2 <= dims, got shape ",
                                        input_shape.DebugString()));

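    // xla::GetMatrixDiagonal extracts the main diagonal of each innermost
    // 2D slice, e.g. a [2, 4, 3, 3] input yields a [2, 4, 3] output.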
    xla::XlaOp input = ctx->Input(0);
    ctx->SetOutput(0, xla::GetMatrixDiagonal(input));
  }
};

REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);

}  // namespace
}  // namespace tensorflow