/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
18 
19 #include <limits>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "tensorflow/core/framework/bounds_check.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/variant_op_registry.h"
28 
namespace tensorflow {

namespace shape_op_helpers {
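// Returns (via *shape) the shape of the tensor at input `input_index`.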
inline Status GetShape(OpKernelContext* ctx, int input_index,
                       TensorShape* shape) {
  *shape = ctx->input(input_index).shape();
  return Status::OK();
}
}  // namespace shape_op_helpers

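// Emits a 1-D tensor of length input.dims() holding the size of each
// dimension of the input. OutType is the output element type, int32 or int64.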
template <typename OutType>
class ShapeOp : public OpKernel {
 public:
  explicit ShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));
    const int rank = shape.dims();
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({rank}), &out));
    auto vec = out->vec<OutType>();
    for (int i = 0; i < rank; ++i) {
      int64 dim_size = shape.dim_size(i);
      if (out->dtype() == DT_INT32) {
        OP_REQUIRES(
            ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
            errors::InvalidArgument("Shape output type is 32-bit but dim ", i,
                                    " is ", dim_size));
      }
      vec(i) = static_cast<OutType>(dim_size);
    }
  }

  bool IsExpensive() override { return false; }
};

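// N-input variant of ShapeOp: emits one 1-D shape vector per input.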
template <typename OutType>
class ShapeNOp : public OpKernel {
 public:
  explicit ShapeNOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      TensorShape shape;
      OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, i, &shape));
      const int dims = shape.dims();
      Tensor* out = nullptr;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out));
      auto vec = out->vec<OutType>();

      for (int j = 0; j < dims; ++j) {
        int64 dim_size = shape.dim_size(j);
        if (out->dtype() == DT_INT32) {
          OP_REQUIRES(
              ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
              errors::InvalidArgument("ShapeN output type is 32-bit but shape ",
                                      i, " dim ", j, " is ", dim_size));
        }
        vec(j) = static_cast<OutType>(dim_size);
      }
    }
  }

  bool IsExpensive() override { return false; }
};

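// Emits a scalar int32 holding the number of dimensions of the input.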
class RankOp : public OpKernel {
 public:
  explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));
    const int rank = shape.dims();
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
    out->scalar<int32>()() = rank;
  }

  bool IsExpensive() override { return false; }
};

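// Emits a scalar holding the total number of elements in the input. OutType is
// the output element type, int32 or int64.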
template <typename OutType>
class SizeOp : public OpKernel {
 public:
  explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    TensorShape shape;
    OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));
    const int64 size = shape.num_elements();
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
    if (out->dtype() == DT_INT32) {
      OP_REQUIRES(
          ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()),
          errors::InvalidArgument("Number of elements was larger than "
                                  "representable by 32-bit output type"));
    }
    out->scalar<OutType>()() = static_cast<OutType>(size);
  }

  bool IsExpensive() override { return false; }
};

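// Inserts a dimension of size 1 into the input shape at the axis given by the
// scalar 'dim' input; negative values count from the end, numpy-style. Tdim is
// the element type of the 'dim' input, int32 or int64.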
template <typename Tdim>
class ExpandDimsOp : public OpKernel {
 public:
  explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
                errors::InvalidArgument("ExpandDims on Variant not supported"));

    OP_REQUIRES(
        ctx, (ctx->input(1).NumElements() == 1),
        errors::InvalidArgument("'dim' must be a tensor with a single value"));
    Tdim dim = ctx->input(1).flat<Tdim>()(0);
    OP_REQUIRES(
        ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()),
        errors::InvalidArgument("Tried to expand dim index ", dim,
                                " for tensor with ", ctx->input(0).dims(),
                                " dimensions."));

    auto existing_dims = ctx->input(0).shape().dim_sizes();
    // Safe - # elements in tensor dims bounded.
    const int existing_dims_size = static_cast<int>(existing_dims.size());
    std::vector<int64> new_shape(existing_dims_size);
    for (size_t i = 0; i < new_shape.size(); ++i) {
      new_shape[i] = existing_dims[i];
    }

    // We emulate numpy's interpretation of the dim axis when
    // -1 - input.dims() <= dim <= input.dims().
    if (dim < 0) {
      dim += existing_dims.size() + 1;
    }

    // Clamp to the end if needed.
    dim = std::min<Tdim>(dim, existing_dims_size);
    new_shape.emplace(new_shape.begin() + dim, 1);
    const TensorShape output_shape(new_shape);

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
    if (!output->CopyFrom(ctx->input(0), output_shape)) {
      // This should never happen, since the sizes of the input and output
      // should always be the same (we only insert a dimension of size 1).
      ctx->SetStatus(
          errors::Internal("Could not expand dimension with input shape ",
                           ctx->input(0).shape().DebugString(),
                           " and output shape ", output_shape.DebugString()));
    }
  }

  bool IsExpensive() override { return false; }
};

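// Removes size-1 dimensions from the input shape. If the 'squeeze_dims'
// attribute is non-empty, only the listed dimensions are removed, and each of
// them must have size 1.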
class SqueezeOp : public OpKernel {
 public:
  explicit SqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    std::vector<int32> squeeze_dims;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
    squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
  }

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
                errors::InvalidArgument("Squeeze on Variant not supported"));

    auto existing_dims = ctx->input(0).shape().dim_sizes();
    const int existing_dims_size = static_cast<int>(existing_dims.size());
    std::vector<int64> new_shape;

    std::unordered_set<int32> wrapped_squeeze_dims;
    wrapped_squeeze_dims.reserve(squeeze_dims_.size());
    // Validate squeeze dims against the input.
    for (int32 dim : squeeze_dims_) {
      OP_REQUIRES(
          ctx, (dim >= -ctx->input(0).dims() && dim < ctx->input(0).dims()),
          errors::InvalidArgument("Tried to squeeze dim index ", dim,
                                  " for tensor with ", ctx->input(0).dims(),
                                  " dimensions."));
      // If dim is < 0, we wrap around (-1 means the last element).
      if (dim < 0) {
        dim = existing_dims_size + dim;
      }

      wrapped_squeeze_dims.insert(dim);
    }

    for (int i = 0; i < existing_dims_size; ++i) {
      auto existing_dim = existing_dims[i];

      // If the set of squeeze dims is non-empty, only squeeze those dimensions.
      if (!wrapped_squeeze_dims.empty()) {
        if (wrapped_squeeze_dims.count(i) > 0) {
          OP_REQUIRES(ctx, existing_dim == 1,
                      errors::InvalidArgument(
                          "Can not squeeze dim[", i,
                          "], expected a dimension of 1, got ", existing_dim));
        } else {
          // This dimension is not being squeezed.
          new_shape.push_back(existing_dim);
        }
      } else {
        // Copy over all non-1-length dimensions.
        if (existing_dim != 1) {
          new_shape.push_back(existing_dim);
        }
      }
    }

    const TensorShape output_shape(new_shape);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
    if (!output->CopyFrom(ctx->input(0), output_shape)) {
      // This should never happen, since the sizes of the input and
      // output should always be the same.
      ctx->SetStatus(errors::Internal("Could not squeeze input with shape ",
                                      ctx->input(0).shape().DebugString(),
                                      " and output shape ",
                                      output_shape.DebugString()));
    }
  }

  bool IsExpensive() override { return false; }

 private:
  std::unordered_set<int32> squeeze_dims_;
};

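// The kernel classes above are registered for concrete devices and output
// types in the corresponding .cc files. As an illustrative sketch (not the
// verbatim registration list), a CPU registration of ShapeOp with an int32
// output would look like:
//
//   REGISTER_KERNEL_BUILDER(Name("Shape")
//                               .Device(DEVICE_CPU)
//                               .HostMemory("output")
//                               .TypeConstraint<int32>("out_type"),
//                           ShapeOp<int32>);
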
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_