/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Shape Ops.

#include "absl/strings/str_format.h"
#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace {

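// Kernel for the Shape op: returns the shape of its input as a rank-1
// tensor. Dimension sizes are read back with xla::GetDimensionSize so that
// dynamic dimensions report their runtime values.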
class ShapeOp : public XlaOpKernel {
 public:
  explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    std::vector<xla::XlaOp> operands;
    const int rank = input_shape.dims();
    if (rank != 0) {
      for (int64_t i = 0; i < rank; ++i) {
        operands.push_back(xla::Broadcast(
            xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(0), i),
                                    ctx->output_xla_type(0)),
            {1}));
      }

      ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), operands, 0));
    } else {
      // A rank-0 tensor has no dynamic dimensions, so emit a constant output.
      Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
      OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
      ctx->SetConstantOutput(0, shape_constant);
    }
  }

 private:
  DataType out_dtype_;
};

REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp);

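// Kernel for the XlaSetBound op: attaches an upper bound to an int32 scalar
// whose value is only known at run time, via a "SetBound" custom call that
// serves as a hint for XLA's bounded dynamic shape inference.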
class XlaSetBoundOp : public XlaOpKernel {
 public:
  explicit XlaSetBoundOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape("input");
    const TensorShape bound_shape = ctx->InputShape("bound");

    OP_REQUIRES(
        ctx,
        ctx->InputType("bound") == DT_INT32 &&
            ctx->InputType("input") == DT_INT32,
        errors::InvalidArgument(
            "XlaSetBound can only set a bound on an int32 scalar value; got ",
            input_shape.DebugString()));

    OP_REQUIRES(
        ctx, input_shape.dims() == 0,
        errors::InvalidArgument("XlaSetBound should only be used to set a "
                                "bound on an int32 scalar value; got ",
                                input_shape.DebugString()));

    OP_REQUIRES(
        ctx, bound_shape.dims() == 0,
        errors::InvalidArgument("XlaSetBound should only be used to set a "
                                "bound on an int32 scalar value; got ",
                                bound_shape.DebugString()));
    int64_t bound;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound));
    xla::Literal bound_literal = xla::LiteralUtil::CreateR0<int32>(bound);
    xla::XlaOp result =
        xla::CustomCall(ctx->builder(), "SetBound", {ctx->Input("input")},
                        ctx->InputXlaShape("input").ValueOrDie(), "", false, {},
                        &bound_literal);
    ctx->SetOutput(0, result);
  }
};

REGISTER_XLA_OP(Name("XlaSetBound").CompileTimeConstantInput("bound"),
                XlaSetBoundOp);

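// Kernel for the XlaSetDynamicDimensionSize op: marks dimension `dim_index`
// of the input as dynamic, with its runtime size taken from the `size`
// operand (a thin wrapper around xla::SetDimensionSize).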
class XlaSetDynamicDimensionSizeOp : public XlaOpKernel {
 public:
  explicit XlaSetDynamicDimensionSizeOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape dim_index_shape = ctx->InputShape("dim_index");
    const TensorShape size_shape = ctx->InputShape("size");

    OP_REQUIRES(ctx,
                ctx->InputType("dim_index") == DT_INT32 &&
                    ctx->InputType("size") == DT_INT32,
                errors::InvalidArgument("dim_index and size have to be int32 "
                                        "for XlaSetDynamicDimensionSizeOp"));

    OP_REQUIRES(
        ctx, dim_index_shape.dims() == 0 && size_shape.dims() == 0,
        errors::InvalidArgument("XlaSetDynamicDimensionSizeOp's dim_index and "
                                "size have to be int32 scalar values"));
    int64_t dim_index;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("dim_index", &dim_index));

    xla::XlaOp result =
        xla::SetDimensionSize(ctx->Input(0), ctx->Input("size"), dim_index);
    ctx->SetOutput(0, result);
  }
};

REGISTER_XLA_OP(
    Name("XlaSetDynamicDimensionSize").CompileTimeConstantInput("dim_index"),
    XlaSetDynamicDimensionSizeOp);

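// Kernel for the XlaRemoveDynamicDimensionSize op: the inverse of the op
// above, turning dimension `dim_index` back into a static dimension (a thin
// wrapper around xla::RemoveDynamicDimension).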
class XlaRemoveDynamicDimensionSizeOp : public XlaOpKernel {
 public:
  explicit XlaRemoveDynamicDimensionSizeOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape dim_index_shape = ctx->InputShape("dim_index");

    OP_REQUIRES(ctx, ctx->InputType("dim_index") == DT_INT32,
                errors::InvalidArgument("dim_index has to be int32 for "
                                        "XlaRemoveDynamicDimensionSizeOp"));

    OP_REQUIRES(
        ctx, dim_index_shape.dims() == 0,
        errors::InvalidArgument("XlaRemoveDynamicDimensionSizeOp's dim_index "
                                "has to be an int32 scalar value"));
    int64_t dim_index;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("dim_index", &dim_index));

    xla::XlaOp result = xla::RemoveDynamicDimension(ctx->Input(0), dim_index);
    ctx->SetOutput(0, result);
  }
};

REGISTER_XLA_OP(
    Name("XlaRemoveDynamicDimensionSize").CompileTimeConstantInput("dim_index"),
    XlaRemoveDynamicDimensionSizeOp);

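// Kernel for the ShapeN op: like Shape, but emits one shape vector per input.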
class ShapeNOp : public XlaOpKernel {
 public:
  explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      const TensorShape input_shape = ctx->InputShape(i);
      std::vector<xla::XlaOp> operands;

      const int rank = input_shape.dims();
      if (rank != 0) {
        // Each dimension can be dynamic, so use GetDimensionSize to get the
        // runtime dimension.
        for (int64_t dim = 0; dim < rank; ++dim) {
          operands.push_back(xla::Broadcast(
              xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(i), dim),
                                      ctx->output_xla_type(i)),
              {1}));
        }

        ctx->SetOutput(i, xla::ConcatInDim(ctx->builder(), operands, 0));
      } else {
        // A rank-0 tensor has no dynamic dimensions, so emit a constant
        // output.
        Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
        OP_REQUIRES_OK(ctx,
                       TensorShapeToConstant(input_shape, &shape_constant));
        ctx->SetConstantOutput(i, shape_constant);
      }
    }
  }

  bool IsExpensive() override { return false; }

 private:
  DataType out_dtype_;
};
REGISTER_XLA_OP(Name("ShapeN").CompilationOnly().IsMetadataOp(), ShapeNOp);

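// Kernel for the Rank op. The rank of a tensor is always known at compile
// time, so the output is a compile-time constant.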
class RankOp : public XlaOpKernel {
 public:
  explicit RankOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    const int rank = input_shape.dims();
    Tensor rank_constant(DT_INT32, TensorShape({}));
    rank_constant.scalar<int32>()() = rank;

    ctx->SetConstantOutput(0, rank_constant);
  }
};

REGISTER_XLA_OP(Name("Rank").CompilationOnly().IsMetadataOp(), RankOp);

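// Kernel for the Size op: computes the number of elements as the product of
// the runtime dimension sizes, so dynamic dimensions are handled correctly.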
class SizeOp : public XlaOpKernel {
 public:
  explicit SizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx,
                FastBoundsCheck(input_shape.num_elements(),
                                std::numeric_limits<int32>::max()),
                errors::InvalidArgument("Size does not work for tensors > "
                                        "int32 max."));
    const int rank = input_shape.dims();
    xla::XlaBuilder* builder = ctx->builder();
    auto size = xla::One(builder, xla::U32);
    for (int64_t i = 0; i < rank; ++i) {
      size = xla::Mul(
          size, xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(0), i),
                                        xla::U32));
    }
    size = xla::ConvertElementType(size, ctx->output_xla_type(0));
    ctx->SetOutput(0, size);
  }
};

REGISTER_XLA_OP(Name("Size").CompilationOnly().IsMetadataOp(), SizeOp);

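// Kernel for the ExpandDims op: inserts a dimension of size 1 at the given
// index, implemented as a reshape to the augmented shape.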
class ExpandDimsOp : public XlaOpKernel {
 public:
  explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape("input");
    const TensorShape dim_shape = ctx->InputShape("dim");

    std::vector<int64_t> dims;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dim", &dims));
    OP_REQUIRES(ctx, dims.size() == 1,
                errors::InvalidArgument(absl::StrCat(
                    "dim input to ExpandDims must be a scalar; got ",
                    dim_shape.DebugString())));
    int dim = dims[0];

    OP_REQUIRES(ctx,
                (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()),
                errors::InvalidArgument("Tried to expand dim index ", dim,
                                        " for tensor with ", input_shape.dims(),
                                        " dimensions."));

    auto existing_dims = input_shape.dim_sizes();
    // Safe: the number of dimensions in a tensor is bounded.
    const int existing_dims_size = static_cast<int>(existing_dims.size());
    std::vector<int64_t> new_shape(existing_dims_size);
    for (size_t i = 0; i < new_shape.size(); ++i) {
      new_shape[i] = existing_dims[i];
    }

    // We emulate numpy's interpretation of the dim axis when
    // -1 - input.dims() <= dim <= input.dims().
    if (dim < 0) {
      dim += existing_dims.size() + 1;
    }

    // Clamp to the end if needed.
    dim = std::min<int32>(dim, existing_dims_size);
    new_shape.emplace(new_shape.begin() + dim, 1);

    ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape));
  }
};
REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"),
                ExpandDimsOp);

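// Kernel for the Squeeze op: removes dimensions of size 1, either the
// explicitly listed ones (which must have size 1) or, if none are given,
// every static size-1 dimension.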
class SqueezeOp : public XlaOpKernel {
 public:
  explicit SqueezeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    std::vector<int32> squeeze_dims;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
    squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
  }

  void Compile(XlaOpKernelContext* ctx) override {
    StatusOr<xla::Shape> input_shape = ctx->builder()->GetShape(ctx->Input(0));
    OP_REQUIRES_OK(ctx, input_shape.status());
    xla::Shape shape = input_shape.ValueOrDie();
    int64_t rank = shape.rank();

    std::unordered_set<int32> wrapped_squeeze_dims;
    wrapped_squeeze_dims.reserve(squeeze_dims_.size());
    std::vector<int64_t> new_shape;
    // Validate squeeze dims against the input.
    for (int32_t dim : squeeze_dims_) {
      OP_REQUIRES(
          ctx, (dim >= -rank && dim < rank),
          errors::InvalidArgument("Tried to squeeze dim index ", dim,
                                  " for tensor with ", rank, " dimensions."));
      // If dim is < 0, we wrap around (-1 means the last element).
      if (dim < 0) {
        dim = rank + dim;
      }

      wrapped_squeeze_dims.insert(dim);
    }

    for (int i = 0; i < rank; ++i) {
      auto existing_dim = shape.dimensions(i);

      // If squeeze dims were given explicitly, only squeeze those dimensions.
      if (!wrapped_squeeze_dims.empty()) {
        if (wrapped_squeeze_dims.count(i) > 0) {
          OP_REQUIRES(ctx, existing_dim == 1,
                      errors::InvalidArgument(
                          "Tried to explicitly squeeze dimension ", i,
                          " but dimension was not 1: ", existing_dim));
        } else {
          // This dimension is not being squeezed.
          new_shape.push_back(existing_dim);
        }
      } else {
        OP_REQUIRES(
            ctx, !shape.is_dynamic_dimension(i),
            errors::InvalidArgument("Squeeze op does not support bounded "
                                    "dynamic dimensions. Input shape: ",
                                    shape.DebugString()));
        // Copy over all dimensions whose size is not 1.
        if (existing_dim != 1) {
          new_shape.push_back(existing_dim);
        }
      }
    }

    ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape));
  }

 private:
  std::unordered_set<int32> squeeze_dims_;
};

REGISTER_XLA_OP(Name("Squeeze"), SqueezeOp);

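// Kernel for the ZerosLike op. Handles both TensorList inputs (producing a
// zero-initialized list with the same shape and push index) and dense inputs
// (broadcasting zero and propagating dynamic dimension sizes).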
class ZerosLikeOp : public XlaOpKernel {
 public:
  explicit ZerosLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    if (IsTensorListInput(ctx, 0)) {
      // Input is a TensorList.

      // Check that the TensorList input is initialized.
      xla::XlaOp list = ctx->Input(0);
      bool is_initialized;
      OP_REQUIRES_OK(ctx, IsTensorListInitialized(list, &is_initialized));
      OP_REQUIRES(
          ctx, is_initialized,
          errors::InvalidArgument(
              "TensorList input for ZerosLike op is an uninitialized list"));

      auto list_shape_or = ctx->builder()->GetShape(list);
      OP_REQUIRES_OK(ctx, list_shape_or.status());
      const xla::Shape& list_shape = list_shape_or.ValueOrDie();
      std::vector<std::vector<xla::XlaOp>> list_dynamic_dims;
      list_dynamic_dims.reserve(list_shape.tuple_shapes_size() - 1);
      for (int i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) {
        // Collect the runtime dynamic dimension sizes of each list element so
        // they can be propagated to the zero-initialized list.
        std::vector<xla::XlaOp> dynamic_dims;
        const xla::Shape& shape = list_shape.tuple_shapes(i);
        auto sub_element = xla::GetTupleElement(list, i);
        for (int64_t dim = 0; dim < shape.dimensions_size(); ++dim) {
          dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim));
        }
        list_dynamic_dims.push_back(dynamic_dims);
      }
      xla::XlaOp new_list;
      OP_REQUIRES_OK(
          ctx, CreateZerosTensorListWithShape(ctx->builder(), list_shape,
                                              list_dynamic_dims, &new_list));

      xla::XlaOp push_index;
      OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list, &push_index));

      xla::XlaOp result;
      OP_REQUIRES_OK(ctx,
                     SetTensorListPushIndex(new_list, push_index, &result));
      ctx->SetTensorListOutput(0, result);
    } else {
      auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
      xla::XlaOp input = ctx->Input(0);
      auto input_shape = ctx->InputXlaShape(0).ValueOrDie();
      auto result = xla::Broadcast(zero, input_shape.dimensions());

      // Set up the dynamic dimensions of the broadcast.
      for (int64_t i = 0; i < input_shape.dimensions_size(); ++i) {
        if (input_shape.is_dynamic_dimension(i)) {
          xla::XlaOp input_dynamic_dim = xla::GetDimensionSize(input, i);
          result = xla::SetDimensionSize(result, input_dynamic_dim, i);
        }
      }

      ctx->SetOutput(0, result);
    }
  }
};

REGISTER_XLA_OP(Name("ZerosLike").AllowVariantTypes(), ZerosLikeOp);

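// Kernel for the OnesLike op: broadcasts a scalar one to the input's static
// shape. Unlike ZerosLike above, dynamic dimension sizes are not propagated
// to the output.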
class OnesLikeOp : public XlaOpKernel {
 public:
  explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);

    auto one = XlaHelpers::One(ctx->builder(), input_type(0));
    ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes()));
  }
};

REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp);

}  // namespace
}  // namespace tensorflow