• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
18 
19 #include <limits>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "absl/container/inlined_vector.h"
24 #include "tensorflow/core/common_runtime/dma_helper.h"
25 #include "tensorflow/core/framework/bounds_check.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor_shape.h"
29 #include "tensorflow/core/framework/variant_op_registry.h"
30 
31 namespace tensorflow {
32 
33 namespace shape_op_helpers {
GetShape(OpKernelContext * ctx,int input_index,TensorShape * shape)34 inline Status GetShape(OpKernelContext* ctx, int input_index,
35                        TensorShape* shape) {
36   *shape = ctx->input(input_index).shape();
37   return Status::OK();
38 }
39 }  // namespace shape_op_helpers
40 
41 template <typename OutType>
42 class ShapeOp : public OpKernel {
43  public:
ShapeOp(OpKernelConstruction * ctx)44   explicit ShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
45 
Compute(OpKernelContext * ctx)46   void Compute(OpKernelContext* ctx) override {
47     TensorShape shape;
48     OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));
49     const int rank = shape.dims();
50     Tensor* out = nullptr;
51     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({rank}), &out));
52     auto vec = out->vec<OutType>();
53     for (int i = 0; i < rank; ++i) {
54       int64_t dim_size = shape.dim_size(i);
55       if (out->dtype() == DT_INT32) {
56         OP_REQUIRES(
57             ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
58             errors::InvalidArgument("Shape output type is 32-bit ", " but dim ",
59                                     i, " is ", dim_size));
60       }
61       vec(i) = static_cast<OutType>(dim_size);
62     }
63   }
64 
IsExpensive()65   bool IsExpensive() override { return false; }
66 };
67 
68 template <typename OutType>
69 class ShapeNOp : public OpKernel {
70  public:
ShapeNOp(OpKernelConstruction * ctx)71   explicit ShapeNOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
72 
Compute(OpKernelContext * ctx)73   void Compute(OpKernelContext* ctx) override {
74     for (int i = 0; i < ctx->num_inputs(); ++i) {
75       TensorShape shape;
76       OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, i, &shape));
77       const int dims = shape.dims();
78       Tensor* out = nullptr;
79       OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out));
80       auto vec = out->vec<OutType>();
81 
82       for (int j = 0; j < dims; ++j) {
83         int64_t dim_size = shape.dim_size(j);
84         if (out->dtype() == DT_INT32) {
85           OP_REQUIRES(
86               ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
87               errors::InvalidArgument("ShapeN output type is 32-bit but shape ",
88                                       i, " dim ", j, " is ", dim_size));
89         }
90         vec(j) = static_cast<OutType>(dim_size);
91       }
92     }
93   }
94 
IsExpensive()95   bool IsExpensive() override { return false; }
96 };
97 
98 class RankOp : public OpKernel {
99  public:
RankOp(OpKernelConstruction * ctx)100   explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
101 
Compute(OpKernelContext * ctx)102   void Compute(OpKernelContext* ctx) override {
103     TensorShape shape;
104     OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));
105     const int rank = shape.dims();
106     Tensor* out = nullptr;
107     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
108     out->scalar<int32>()() = rank;
109   }
110 
IsExpensive()111   bool IsExpensive() override { return false; }
112 };
113 
114 template <typename OutType>
115 class SizeOp : public OpKernel {
116  public:
SizeOp(OpKernelConstruction * ctx)117   explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
118 
Compute(OpKernelContext * ctx)119   void Compute(OpKernelContext* ctx) override {
120     TensorShape shape;
121     OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));
122     const int64_t size = shape.num_elements();
123     Tensor* out = nullptr;
124     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
125     if (out->dtype() == DT_INT32) {
126       OP_REQUIRES(
127           ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()),
128           errors::InvalidArgument("Number of elements was larger than "
129                                   "representable by 32-bit output type"));
130     }
131     out->scalar<OutType>()() = static_cast<OutType>(size);
132   }
133 
IsExpensive()134   bool IsExpensive() override { return false; }
135 };
136 
137 template <typename Tdim>
138 class ExpandDimsOp : public OpKernel {
139  public:
ExpandDimsOp(OpKernelConstruction * ctx)140   explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
141 
Compute(OpKernelContext * ctx)142   void Compute(OpKernelContext* ctx) override {
143     const Tensor& input_t = ctx->input(0);
144     OP_REQUIRES(ctx, input_t.dtype() != DT_VARIANT,
145                 errors::InvalidArgument("ExpandDims on Variant not supported"));
146 
147     const Tensor& dim_t = ctx->input(1);
148     OP_REQUIRES(
149         ctx, (dim_t.NumElements() == 1),
150         errors::InvalidArgument("'dim' must be a tensor with a single value"));
151     DCHECK_EQ(dim_t.dtype(), DataTypeToEnum<Tdim>::v());
152     Tdim dim = *static_cast<const Tdim*>(DMAHelper::base(&dim_t));
153     const TensorShape& input_shape = input_t.shape();
154     int input_dims = input_shape.dims();
155     OP_REQUIRES(ctx, dim >= -1 - input_dims && dim <= input_dims,
156                 errors::InvalidArgument("Tried to expand dim index ", dim,
157                                         " for tensor with ", input_dims,
158                                         " dimensions."));
159 
160     // We emulate numpy's interpretation of the dim axis when
161     // -input.dims() >= dim <= input.dims().
162     if (dim < 0) {
163       // Clamp to the end if needed.
164       dim = std::min<Tdim>(dim + input_dims + 1, input_dims);
165     }
166 
167     // Compute new shape with an additional dimension.
168     absl::InlinedVector<int64, 8> output_shape_vec(input_dims + 1);
169     for (int64_t i = 0; i < dim; ++i) {
170       output_shape_vec[i] = input_shape.dim_size(i);
171     }
172     output_shape_vec[dim] = 1;
173     for (int64_t i = dim + 1; i < input_dims + 1; ++i) {
174       output_shape_vec[i] = input_shape.dim_size(i - 1);
175     }
176     TensorShape output_shape(output_shape_vec);
177 
178     Tensor output_t;
179     if (!output_t.CopyFrom(input_t, output_shape)) {
180       // This should never happen, since the sizes of the input and output
181       // should always be the same (we only expand the dimension with 1).
182       ctx->SetStatus(
183           errors::Internal("Could not expand dimension with input shape ",
184                            ctx->input(0).shape().DebugString(),
185                            " and output shape ", output_shape.DebugString()));
186     }
187     ctx->set_output(0, std::move(output_t));
188   }
189 
IsExpensive()190   bool IsExpensive() override { return false; }
191 };
192 
193 class SqueezeOp : public OpKernel {
194  public:
SqueezeOp(OpKernelConstruction * ctx)195   explicit SqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
196     std::vector<int32> squeeze_dims;
197     OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
198     squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
199   }
200 
Compute(OpKernelContext * ctx)201   void Compute(OpKernelContext* ctx) override {
202     OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
203                 errors::InvalidArgument("Squeeze on Variant not supported"));
204 
205     auto existing_dims = ctx->input(0).shape().dim_sizes();
206     const int existing_dims_size = static_cast<int>(existing_dims.size());
207     std::vector<int64> new_shape;
208 
209     std::unordered_set<int32> wrapped_squeeze_dims;
210     wrapped_squeeze_dims.reserve(squeeze_dims_.size());
211     // Validate squeeze dims against the input.
212     for (int32_t dim : squeeze_dims_) {
213       OP_REQUIRES(
214           ctx, (dim >= -ctx->input(0).dims() && dim < ctx->input(0).dims()),
215           errors::InvalidArgument("Tried to squeeze dim index ", dim,
216                                   " for tensor with ", ctx->input(0).dims(),
217                                   " dimensions."));
218       // If dim is < 0, we wrap around (-1 means the last element).
219       if (dim < 0) {
220         dim = existing_dims_size + dim;
221       }
222 
223       wrapped_squeeze_dims.insert(dim);
224     }
225 
226     for (int i = 0; i < existing_dims_size; ++i) {
227       auto existing_dim = existing_dims[i];
228 
229       // If squeeze_set is non-empty, only squeeze those dimensions.
230       if (!wrapped_squeeze_dims.empty()) {
231         if (wrapped_squeeze_dims.count(i) > 0) {
232           OP_REQUIRES(ctx, existing_dim == 1,
233                       errors::InvalidArgument(
234                           "Can not squeeze dim[", i,
235                           "], expected a dimension of 1, got ", existing_dim));
236         } else {
237           // This dimension is not being squeezed.
238           new_shape.push_back(existing_dim);
239         }
240       } else {
241         // Copy over all non-1-length dimensions.
242         if (existing_dim != 1) {
243           new_shape.push_back(existing_dim);
244         }
245       }
246     }
247 
248     const TensorShape output_shape(new_shape);
249     Tensor* output = nullptr;
250     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
251     if (!output->CopyFrom(ctx->input(0), output_shape)) {
252       // This should never happen, since the sizes of the input and
253       // output should always be the same.
254       ctx->SetStatus(errors::Internal("Could not squeeze input with shape ",
255                                       ctx->input(0).shape().DebugString(),
256                                       " and output shape ",
257                                       output_shape.DebugString()));
258     }
259   }
260 
IsExpensive()261   bool IsExpensive() override { return false; }
262 
263  private:
264   std::unordered_set<int32> squeeze_dims_;
265 };
266 
267 }  // namespace tensorflow
268 
269 #endif  // TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
270