• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // XLA-specific Shape Ops.
17 
18 #include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
19 #include "tensorflow/compiler/tf2xla/type_util.h"
20 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
21 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
22 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
23 #include "tensorflow/compiler/xla/client/lib/constants.h"
24 #include "tensorflow/compiler/xla/client/xla_builder.h"
25 #include "tensorflow/core/framework/bounds_check.h"
26 #include "tensorflow/core/framework/kernel_def_builder.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 
29 namespace tensorflow {
30 namespace {
31 
32 class ShapeOp : public XlaOpKernel {
33  public:
ShapeOp(OpKernelConstruction * ctx)34   explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
35     OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
36   }
37 
Compile(XlaOpKernelContext * ctx)38   void Compile(XlaOpKernelContext* ctx) override {
39     const TensorShape input_shape = ctx->InputShape(0);
40     Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
41     OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
42     ctx->SetConstantOutput(0, shape_constant);
43   }
44 
45  private:
46   DataType out_dtype_;
47 };
48 
49 REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp);
50 
51 class ShapeNOp : public XlaOpKernel {
52  public:
ShapeNOp(OpKernelConstruction * ctx)53   explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
54     OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
55   }
56 
Compile(XlaOpKernelContext * ctx)57   void Compile(XlaOpKernelContext* ctx) override {
58     for (int i = 0; i < ctx->num_inputs(); ++i) {
59       const TensorShape input_shape = ctx->InputShape(i);
60       Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
61       OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
62       ctx->SetConstantOutput(i, shape_constant);
63     }
64   }
65 
IsExpensive()66   bool IsExpensive() override { return false; }
67 
68  private:
69   DataType out_dtype_;
70 };
71 REGISTER_XLA_OP(Name("ShapeN").CompilationOnly().IsMetadataOp(), ShapeNOp);
72 
73 class RankOp : public XlaOpKernel {
74  public:
RankOp(OpKernelConstruction * ctx)75   explicit RankOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
76 
Compile(XlaOpKernelContext * ctx)77   void Compile(XlaOpKernelContext* ctx) override {
78     const TensorShape input_shape = ctx->InputShape(0);
79     const int rank = input_shape.dims();
80     Tensor rank_constant(DT_INT32, TensorShape({}));
81     rank_constant.scalar<int32>()() = rank;
82 
83     ctx->SetConstantOutput(0, rank_constant);
84   }
85 };
86 
87 REGISTER_XLA_OP(Name("Rank").CompilationOnly().IsMetadataOp(), RankOp);
88 
89 class SizeOp : public XlaOpKernel {
90  public:
SizeOp(OpKernelConstruction * ctx)91   explicit SizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
92 
Compile(XlaOpKernelContext * ctx)93   void Compile(XlaOpKernelContext* ctx) override {
94     const TensorShape input_shape = ctx->InputShape(0);
95     OP_REQUIRES(ctx,
96                 FastBoundsCheck(input_shape.num_elements(),
97                                 std::numeric_limits<int32>::max()),
98                 errors::InvalidArgument("Size does not work for tensors > "
99                                         "int32 max."));
100     Tensor size_constant(DT_INT32, TensorShape({}));
101     const int rank = input_shape.dims();
102     xla::XlaBuilder* builder = ctx->builder();
103     auto size = xla::One(builder, xla::U32);
104     for (int64 i = 0; i < rank; ++i) {
105       size = xla::Mul(size, xla::GetDimensionSize(ctx->Input(0), i));
106     }
107     size = xla::ConvertElementType(size, ctx->output_xla_type(0));
108     ctx->SetOutput(0, size);
109   }
110 };
111 
112 REGISTER_XLA_OP(Name("Size").CompilationOnly().IsMetadataOp(), SizeOp);
113 
114 class ExpandDimsOp : public XlaOpKernel {
115  public:
ExpandDimsOp(OpKernelConstruction * ctx)116   explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
117 
Compile(XlaOpKernelContext * ctx)118   void Compile(XlaOpKernelContext* ctx) override {
119     const TensorShape input_shape = ctx->InputShape("input");
120     const TensorShape dim_shape = ctx->InputShape("dim");
121 
122     std::vector<int64> dims;
123     OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dim", &dims));
124     OP_REQUIRES(ctx, dims.size() == 1,
125                 errors::InvalidArgument(absl::StrCat(
126                     "dim input to ExpandDims must be a scalar; got ",
127                     dim_shape.DebugString())));
128     int dim = dims[0];
129 
130     OP_REQUIRES(ctx,
131                 (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()),
132                 errors::InvalidArgument("Tried to expand dim index ", dim,
133                                         " for tensor with ", input_shape.dims(),
134                                         " dimensions."));
135 
136     auto existing_dims = input_shape.dim_sizes();
137     // Safe - # elements in tensor dims bounded.
138     const int existing_dims_size = static_cast<int>(existing_dims.size());
139     std::vector<int64> new_shape(existing_dims_size);
140     for (size_t i = 0; i < new_shape.size(); ++i) {
141       new_shape[i] = existing_dims[i];
142     }
143 
144     // We emulate numpy's interpretation of the dim axis when
145     // -input.dims() >= dim <= input.dims().
146     if (dim < 0) {
147       dim += existing_dims.size() + 1;
148     }
149 
150     // Clamp to the end if needed.
151     dim = std::min<int32>(dim, existing_dims_size);
152     new_shape.emplace(new_shape.begin() + dim, 1);
153 
154     ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape));
155   }
156 };
157 REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"),
158                 ExpandDimsOp);
159 
160 class SqueezeOp : public XlaOpKernel {
161  public:
SqueezeOp(OpKernelConstruction * ctx)162   explicit SqueezeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
163     std::vector<int32> squeeze_dims;
164     OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
165     squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
166   }
167 
Compile(XlaOpKernelContext * ctx)168   void Compile(XlaOpKernelContext* ctx) override {
169     const TensorShape input_shape = ctx->InputShape(0);
170     auto existing_dims = input_shape.dim_sizes();
171     int existing_dims_size = input_shape.dims();
172     std::vector<int64> new_shape;
173 
174     std::unordered_set<int32> wrapped_squeeze_dims;
175     wrapped_squeeze_dims.reserve(squeeze_dims_.size());
176     // Validate squeeze dims against the input.
177     for (int32 dim : squeeze_dims_) {
178       OP_REQUIRES(ctx, (dim >= -input_shape.dims() && dim < input_shape.dims()),
179                   errors::InvalidArgument("Tried to squeeze dim index ", dim,
180                                           " for tensor with ",
181                                           input_shape.dims(), " dimensions."));
182       // If dim is < 0, we wrap around (-1 means the last element).
183       if (dim < 0) {
184         dim = existing_dims_size + dim;
185       }
186 
187       wrapped_squeeze_dims.insert(dim);
188     }
189 
190     for (int i = 0; i < existing_dims_size; ++i) {
191       auto existing_dim = existing_dims[i];
192 
193       // If squeeze_set is non-empty, only squeeze those dimensions.
194       if (!wrapped_squeeze_dims.empty()) {
195         if (wrapped_squeeze_dims.count(i) > 0) {
196           OP_REQUIRES(ctx, existing_dim == 1,
197                       errors::InvalidArgument(
198                           "Tried to explicitly squeeze dimension ", i,
199                           " but dimension was not 1: ", existing_dim));
200         } else {
201           // This dimension is not being squeezed.
202           new_shape.push_back(existing_dim);
203         }
204       } else {
205         // Copy over all non-1-length dimensions.
206         if (existing_dim != 1) {
207           new_shape.push_back(existing_dim);
208         }
209       }
210     }
211 
212     ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape));
213   }
214 
215  private:
216   std::unordered_set<int32> squeeze_dims_;
217 };
218 
219 REGISTER_XLA_OP(Name("Squeeze"), SqueezeOp);
220 
221 class ZerosLikeOp : public XlaOpKernel {
222  public:
ZerosLikeOp(OpKernelConstruction * ctx)223   explicit ZerosLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
224 
Compile(XlaOpKernelContext * ctx)225   void Compile(XlaOpKernelContext* ctx) override {
226     const TensorShape input_shape = ctx->InputShape(0);
227 
228     auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
229     ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes()));
230   }
231 };
232 
233 REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);
234 
235 class OnesLikeOp : public XlaOpKernel {
236  public:
OnesLikeOp(OpKernelConstruction * ctx)237   explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
238 
Compile(XlaOpKernelContext * ctx)239   void Compile(XlaOpKernelContext* ctx) override {
240     const TensorShape input_shape = ctx->InputShape(0);
241 
242     auto one = XlaHelpers::One(ctx->builder(), input_type(0));
243     ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes()));
244   }
245 };
246 
247 REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp);
248 
249 }  // namespace
250 }  // namespace tensorflow
251